From aee456ef18c23bd2227f9e42d04caf35b17b4512 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 10 Feb 2026 14:08:48 +0000 Subject: [PATCH 01/16] Document app-server backpressure --- codex-rs/app-server/README.md | 6 + codex-rs/app-server/src/error_code.rs | 1 + codex-rs/app-server/src/lib.rs | 106 +++++++++++++++-- codex-rs/app-server/src/transport.rs | 158 +++++++++++++++++++++++--- 4 files changed, 248 insertions(+), 23 deletions(-) diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 66d4a501ec5..c40693c601a 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -28,6 +28,12 @@ Supported transports: Websocket transport is currently experimental and unsupported. Do not rely on it for production workloads. +Backpressure behavior: + +- The server uses bounded queues between transport ingress, request processing, and outbound writes. +- When request ingress is saturated, new requests are rejected with a JSON-RPC error code `-32001` and message `"Server overloaded; retry later."`. +- Clients should treat this as retryable and use exponential backoff with jitter. + ## Message Schema Currently, you can dump a TypeScript version of the schema using `codex app-server generate-ts`, or a JSON Schema bundle via `codex app-server generate-json-schema`. Each output is specific to the version of Codex you used to run the command, so the generated artifacts are guaranteed to match that version. 
diff --git a/codex-rs/app-server/src/error_code.rs b/codex-rs/app-server/src/error_code.rs index 1ffd889d404..ca93b2f2d33 100644 --- a/codex-rs/app-server/src/error_code.rs +++ b/codex-rs/app-server/src/error_code.rs @@ -1,2 +1,3 @@ pub(crate) const INVALID_REQUEST_ERROR_CODE: i64 = -32600; pub(crate) const INTERNAL_ERROR_CODE: i64 = -32603; +pub(crate) const OVERLOADED_ERROR_CODE: i64 = -32001; diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index ad049ad3055..0de4ddbf8e6 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -21,6 +21,7 @@ use crate::outgoing_message::OutgoingEnvelope; use crate::outgoing_message::OutgoingMessageSender; use crate::transport::CHANNEL_CAPACITY; use crate::transport::ConnectionState; +use crate::transport::OutboundConnectionState; use crate::transport::TransportEvent; use crate::transport::has_initialized_connections; use crate::transport::route_outgoing_envelope; @@ -37,6 +38,7 @@ use codex_core::config_loader::ConfigLoadError; use codex_core::config_loader::TextRange as CoreTextRange; use codex_feedback::CodexFeedback; use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::task::JoinHandle; use toml::Value as TomlValue; use tracing::error; @@ -61,6 +63,29 @@ mod transport; pub use crate::transport::AppServerTransport; +/// Control-plane messages from the processor/transport side to the outbound router task. +/// +/// `run_main_with_transport` now uses two loops/tasks: +/// - processor loop: handles incoming JSON-RPC and request dispatch +/// - outbound loop: performs potentially slow writes to per-connection writers +/// +/// `OutboundControlEvent` keeps those loops coordinated without sharing mutable +/// connection state directly. 
In particular, the outbound loop needs to know: +/// - when a connection opens/closes so it can route messages correctly +/// - when a connection becomes initialized so broadcast semantics remain unchanged +enum OutboundControlEvent { + /// Register a new writer for an opened connection. + Opened { + connection_id: ConnectionId, + writer: mpsc::Sender, + ready: oneshot::Sender<()>, + }, + /// Remove state for a closed/disconnected connection. + Closed { connection_id: ConnectionId }, + /// Mark the connection as initialized, enabling broadcast delivery. + Initialized { connection_id: ConnectionId }, +} + fn config_warning_from_error( summary: impl Into, err: &std::io::Error, @@ -197,6 +222,8 @@ pub async fn run_main_with_transport( let (transport_event_tx, mut transport_event_rx) = mpsc::channel::(CHANNEL_CAPACITY); let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let (outbound_control_tx, mut outbound_control_rx) = + mpsc::channel::(CHANNEL_CAPACITY); let mut stdio_handles = Vec::>::new(); let mut websocket_accept_handle = None; @@ -336,8 +363,47 @@ pub async fn run_main_with_transport( } } + let outbound_handle = tokio::spawn(async move { + let mut outbound_connections = HashMap::::new(); + loop { + tokio::select! 
{ + envelope = outgoing_rx.recv() => { + let Some(envelope) = envelope else { + break; + }; + route_outgoing_envelope(&mut outbound_connections, envelope).await; + } + event = outbound_control_rx.recv() => { + let Some(event) = event else { + break; + }; + match event { + OutboundControlEvent::Opened { + connection_id, + writer, + ready, + } => { + outbound_connections.insert(connection_id, OutboundConnectionState::new(writer)); + let _ = ready.send(()); + } + OutboundControlEvent::Closed { connection_id } => { + outbound_connections.remove(&connection_id); + } + OutboundControlEvent::Initialized { connection_id } => { + if let Some(connection_state) = outbound_connections.get_mut(&connection_id) { + connection_state.initialized = true; + } + } + } + } + } + } + info!("outbound router task exited (channel closed)"); + }); + let processor_handle = tokio::spawn({ let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx)); + let outbound_control_tx = outbound_control_tx; let cli_overrides: Vec<(String, TomlValue)> = cli_kv_overrides.clone(); let loader_overrides = loader_overrides_for_config_api; let mut processor = MessageProcessor::new(MessageProcessorArgs { @@ -362,9 +428,31 @@ pub async fn run_main_with_transport( }; match event { TransportEvent::ConnectionOpened { connection_id, writer } => { - connections.insert(connection_id, ConnectionState::new(writer)); + let (ready_tx, ready_rx) = oneshot::channel(); + if outbound_control_tx + .send(OutboundControlEvent::Opened { + connection_id, + writer: writer.clone(), + ready: ready_tx, + }) + .await + .is_err() + { + break; + } + if ready_rx.await.is_err() { + break; + } + connections.insert(connection_id, ConnectionState::new()); } TransportEvent::ConnectionClosed { connection_id } => { + if outbound_control_tx + .send(OutboundControlEvent::Closed { connection_id }) + .await + .is_err() + { + break; + } connections.remove(&connection_id); if shutdown_when_no_connections && connections.is_empty() { 
break; @@ -377,6 +465,7 @@ pub async fn run_main_with_transport( warn!("dropping request from unknown connection: {:?}", connection_id); continue; }; + let was_initialized = connection_state.session.initialized; processor .process_request( connection_id, @@ -384,6 +473,14 @@ pub async fn run_main_with_transport( &mut connection_state.session, ) .await; + if !was_initialized && connection_state.session.initialized { + let send_result = outbound_control_tx + .send(OutboundControlEvent::Initialized { connection_id }) + .await; + if send_result.is_err() { + break; + } + } } JSONRPCMessage::Response(response) => { processor.process_response(response).await; @@ -398,12 +495,6 @@ pub async fn run_main_with_transport( } } } - envelope = outgoing_rx.recv() => { - let Some(envelope) = envelope else { - break; - }; - route_outgoing_envelope(&mut connections, envelope).await; - } created = thread_created_rx.recv(), if listen_for_threads => { match created { Ok(thread_id) => { @@ -433,6 +524,7 @@ pub async fn run_main_with_transport( drop(transport_event_tx); let _ = processor_handle.await; + let _ = outbound_handle.await; if let Some(handle) = websocket_accept_handle { handle.abort(); diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index 39fd13212cf..80a7c3424ec 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -1,8 +1,12 @@ +use crate::error_code::OVERLOADED_ERROR_CODE; use crate::message_processor::ConnectionSessionState; use crate::outgoing_message::ConnectionId; use crate::outgoing_message::OutgoingEnvelope; +use crate::outgoing_message::OutgoingError; use crate::outgoing_message::OutgoingMessage; +use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCRequest; use futures::SinkExt; use futures::StreamExt; use owo_colors::OwoColorize; @@ -140,15 +144,27 @@ pub(crate) enum TransportEvent { } pub(crate) struct 
ConnectionState { - pub(crate) writer: mpsc::Sender, pub(crate) session: ConnectionSessionState, } impl ConnectionState { + pub(crate) fn new() -> Self { + Self { + session: ConnectionSessionState::default(), + } + } +} + +pub(crate) struct OutboundConnectionState { + pub(crate) writer: mpsc::Sender, + pub(crate) initialized: bool, +} + +impl OutboundConnectionState { pub(crate) fn new(writer: mpsc::Sender) -> Self { Self { writer, - session: ConnectionSessionState::default(), + initialized: false, } } } @@ -159,6 +175,7 @@ pub(crate) async fn start_stdio_connection( ) -> IoResult<()> { let connection_id = ConnectionId(0); let (writer_tx, mut writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let writer_tx_for_reader = writer_tx.clone(); transport_event_tx .send(TransportEvent::ConnectionOpened { connection_id, @@ -178,11 +195,10 @@ pub(crate) async fn start_stdio_connection( Ok(Some(line)) => { if !forward_incoming_message( &transport_event_tx_for_reader, + &writer_tx_for_reader, connection_id, &line, - ) - .await - { + ) { break; } } @@ -267,6 +283,7 @@ async fn run_websocket_connection( }; let (writer_tx, mut writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let writer_tx_for_reader = writer_tx.clone(); if transport_event_tx .send(TransportEvent::ConnectionOpened { connection_id, @@ -295,7 +312,12 @@ async fn run_websocket_connection( incoming_message = websocket_reader.next() => { match incoming_message { Some(Ok(WebSocketMessage::Text(text))) => { - if !forward_incoming_message(&transport_event_tx, connection_id, &text).await { + if !forward_incoming_message( + &transport_event_tx, + &writer_tx_for_reader, + connection_id, + &text, + ) { break; } } @@ -324,19 +346,14 @@ async fn run_websocket_connection( .await; } -async fn forward_incoming_message( +fn forward_incoming_message( transport_event_tx: &mpsc::Sender, + writer: &mpsc::Sender, connection_id: ConnectionId, payload: &str, ) -> bool { match serde_json::from_str::(payload) { - Ok(message) => 
transport_event_tx - .send(TransportEvent::IncomingMessage { - connection_id, - message, - }) - .await - .is_ok(), + Ok(message) => enqueue_incoming_message(transport_event_tx, writer, connection_id, message), Err(err) => { error!("Failed to deserialize JSONRPCMessage: {err}"); true @@ -344,6 +361,50 @@ async fn forward_incoming_message( } } +fn enqueue_incoming_message( + transport_event_tx: &mpsc::Sender, + writer: &mpsc::Sender, + connection_id: ConnectionId, + message: JSONRPCMessage, +) -> bool { + let event = TransportEvent::IncomingMessage { + connection_id, + message, + }; + match transport_event_tx.try_send(event) { + Ok(()) => true, + Err(mpsc::error::TrySendError::Closed(_)) => false, + Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage { + connection_id, + message: JSONRPCMessage::Request(request), + })) => { + if writer + .try_send(overloaded_error_for_request(request)) + .is_err() + { + warn!("failed to enqueue overload response for connection: {connection_id:?}"); + } + true + } + Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage { .. 
})) => { + warn!("dropping non-request incoming message because processor queue is full"); + true + } + Err(mpsc::error::TrySendError::Full(_)) => true, + } +} + +fn overloaded_error_for_request(request: JSONRPCRequest) -> OutgoingMessage { + OutgoingMessage::Error(OutgoingError { + id: request.id, + error: JSONRPCErrorError { + code: OVERLOADED_ERROR_CODE, + message: "Server overloaded; retry later.".to_string(), + data: None, + }, + }) +} + fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option { let value = match serde_json::to_value(outgoing_message) { Ok(value) => value, @@ -362,7 +423,7 @@ fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option, + connections: &mut HashMap, envelope: OutgoingEnvelope, ) { match envelope { @@ -385,7 +446,7 @@ pub(crate) async fn route_outgoing_envelope( let target_connections: Vec = connections .iter() .filter_map(|(connection_id, connection_state)| { - if connection_state.session.initialized { + if connection_state.initialized { Some(*connection_id) } else { None @@ -416,7 +477,9 @@ pub(crate) fn has_initialized_connections( #[cfg(test)] mod tests { use super::*; + use crate::error_code::OVERLOADED_ERROR_CODE; use pretty_assertions::assert_eq; + use serde_json::json; #[test] fn app_server_transport_parses_stdio_listen_url() { @@ -456,4 +519,67 @@ mod tests { "unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`" ); } + + #[tokio::test] + async fn enqueue_incoming_request_returns_overload_error_when_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + let first_message = + JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: first_message.clone(), + }) 
+ .await + .expect("queue should accept first message"); + + let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(7), + method: "config/read".to_string(), + params: Some(json!({ "includeLayers": false })), + }); + assert!(enqueue_incoming_message( + &transport_event_tx, + &writer_tx, + connection_id, + request + )); + + let queued_event = transport_event_rx + .recv() + .await + .expect("first event should stay queued"); + match queued_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message, + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(message, first_message); + } + _ => panic!("expected queued incoming message"), + } + + let overload = writer_rx + .recv() + .await + .expect("request should receive overload error"); + let overload_json = serde_json::to_value(overload).expect("serialize overload error"); + assert_eq!( + overload_json, + json!({ + "id": 7, + "error": { + "code": OVERLOADED_ERROR_CODE, + "message": "Server overloaded; retry later." + } + }) + ); + } } From f742019bec0c91db03663256b0c8c3f3df099e4a Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 10 Feb 2026 14:15:50 +0000 Subject: [PATCH 02/16] Discuss outbound routing fixes --- codex-rs/app-server/src/lib.rs | 21 +++- codex-rs/app-server/src/message_processor.rs | 16 +-- codex-rs/app-server/src/transport.rs | 123 ++++++++++++++----- 3 files changed, 120 insertions(+), 40 deletions(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 0de4ddbf8e6..3d78dcdadf6 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -83,7 +83,10 @@ enum OutboundControlEvent { /// Remove state for a closed/disconnected connection. Closed { connection_id: ConnectionId }, /// Mark the connection as initialized, enabling broadcast delivery. 
- Initialized { connection_id: ConnectionId }, + Initialized { + connection_id: ConnectionId, + ready: oneshot::Sender<()>, + }, } fn config_warning_from_error( @@ -389,10 +392,14 @@ pub async fn run_main_with_transport( OutboundControlEvent::Closed { connection_id } => { outbound_connections.remove(&connection_id); } - OutboundControlEvent::Initialized { connection_id } => { + OutboundControlEvent::Initialized { + connection_id, + ready, + } => { if let Some(connection_state) = outbound_connections.get_mut(&connection_id) { connection_state.initialized = true; } + let _ = ready.send(()); } } } @@ -474,12 +481,20 @@ pub async fn run_main_with_transport( ) .await; if !was_initialized && connection_state.session.initialized { + let (ready_tx, ready_rx) = oneshot::channel(); let send_result = outbound_control_tx - .send(OutboundControlEvent::Initialized { connection_id }) + .send(OutboundControlEvent::Initialized { + connection_id, + ready: ready_tx, + }) .await; if send_result.is_err() { break; } + if ready_rx.await.is_err() { + break; + } + processor.send_initialize_notifications().await; } } JSONRPCMessage::Response(response) => { diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 26da44df311..fdb981cff1d 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -286,14 +286,6 @@ impl MessageProcessor { self.outgoing.send_response(request_id, response).await; session.initialized = true; - for notification in self.config_warnings.iter().cloned() { - self.outgoing - .send_server_notification(ServerNotification::ConfigWarning( - notification, - )) - .await; - } - return; } } @@ -381,6 +373,14 @@ impl MessageProcessor { self.codex_message_processor.thread_created_receiver() } + pub(crate) async fn send_initialize_notifications(&self) { + for notification in self.config_warnings.iter().cloned() { + self.outgoing + 
.send_server_notification(ServerNotification::ConfigWarning(notification)) + .await; + } + } + pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) { self.codex_message_processor .try_attach_thread_listener(thread_id) diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index 80a7c3424ec..e4155cf4635 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -6,7 +6,6 @@ use crate::outgoing_message::OutgoingError; use crate::outgoing_message::OutgoingMessage; use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::JSONRPCMessage; -use codex_app_server_protocol::JSONRPCRequest; use futures::SinkExt; use futures::StreamExt; use owo_colors::OwoColorize; @@ -198,7 +197,9 @@ pub(crate) async fn start_stdio_connection( &writer_tx_for_reader, connection_id, &line, - ) { + ) + .await + { break; } } @@ -317,7 +318,9 @@ async fn run_websocket_connection( &writer_tx_for_reader, connection_id, &text, - ) { + ) + .await + { break; } } @@ -346,14 +349,16 @@ async fn run_websocket_connection( .await; } -fn forward_incoming_message( +async fn forward_incoming_message( transport_event_tx: &mpsc::Sender, writer: &mpsc::Sender, connection_id: ConnectionId, payload: &str, ) -> bool { match serde_json::from_str::(payload) { - Ok(message) => enqueue_incoming_message(transport_event_tx, writer, connection_id, message), + Ok(message) => { + enqueue_incoming_message(transport_event_tx, writer, connection_id, message).await + } Err(err) => { error!("Failed to deserialize JSONRPCMessage: {err}"); true @@ -361,7 +366,7 @@ fn forward_incoming_message( } } -fn enqueue_incoming_message( +async fn enqueue_incoming_message( transport_event_tx: &mpsc::Sender, writer: &mpsc::Sender, connection_id: ConnectionId, @@ -379,32 +384,24 @@ fn enqueue_incoming_message( message: JSONRPCMessage::Request(request), })) => { if writer - .try_send(overloaded_error_for_request(request)) + 
.try_send(OutgoingMessage::Error(OutgoingError { + id: request.id, + error: JSONRPCErrorError { + code: OVERLOADED_ERROR_CODE, + message: "Server overloaded; retry later.".to_string(), + data: None, + }, + })) .is_err() { warn!("failed to enqueue overload response for connection: {connection_id:?}"); } true } - Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage { .. })) => { - warn!("dropping non-request incoming message because processor queue is full"); - true - } - Err(mpsc::error::TrySendError::Full(_)) => true, + Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(), } } -fn overloaded_error_for_request(request: JSONRPCRequest) -> OutgoingMessage { - OutgoingMessage::Error(OutgoingError { - id: request.id, - error: JSONRPCErrorError { - code: OVERLOADED_ERROR_CODE, - message: "Server overloaded; retry later.".to_string(), - data: None, - }, - }) -} - fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option { let value = match serde_json::to_value(outgoing_message) { Ok(value) => value, @@ -544,12 +541,9 @@ mod tests { method: "config/read".to_string(), params: Some(json!({ "includeLayers": false })), }); - assert!(enqueue_incoming_message( - &transport_event_tx, - &writer_tx, - connection_id, - request - )); + assert!( + enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request).await + ); let queued_event = transport_event_rx .recv() @@ -582,4 +576,75 @@ mod tests { }) ); } + + #[tokio::test] + async fn enqueue_incoming_response_waits_instead_of_dropping_when_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); + let (writer_tx, _writer_rx) = mpsc::channel(1); + + let first_message = + JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + transport_event_tx + .send(TransportEvent::IncomingMessage { + 
connection_id, + message: first_message.clone(), + }) + .await + .expect("queue should accept first message"); + + let response = JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { + id: codex_app_server_protocol::RequestId::Integer(7), + result: json!({"ok": true}), + }); + let transport_event_tx_for_enqueue = transport_event_tx.clone(); + let writer_tx_for_enqueue = writer_tx.clone(); + let enqueue_handle = tokio::spawn(async move { + enqueue_incoming_message( + &transport_event_tx_for_enqueue, + &writer_tx_for_enqueue, + connection_id, + response, + ) + .await + }); + + let queued_event = transport_event_rx + .recv() + .await + .expect("first event should be dequeued"); + match queued_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message, + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(message, first_message); + } + _ => panic!("expected queued incoming message"), + } + + let enqueue_result = enqueue_handle.await.expect("enqueue task should not panic"); + assert!(enqueue_result); + + let forwarded_event = transport_event_rx + .recv() + .await + .expect("response should be forwarded instead of dropped"); + match forwarded_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message: + JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { id, result }), + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(id, codex_app_server_protocol::RequestId::Integer(7)); + assert_eq!(result, json!({"ok": true})); + } + _ => panic!("expected forwarded response message"), + } + } } From bd89f60348adda35e19685550174977749f15a8e Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 10 Feb 2026 14:24:07 +0000 Subject: [PATCH 03/16] Plan outbound router ordering fixes --- codex-rs/app-server/src/lib.rs | 18 +++++++++++++++++- codex-rs/app-server/src/transport.rs | 8 ++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git 
a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 3d78dcdadf6..fed9314f259 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -366,6 +366,7 @@ pub async fn run_main_with_transport( } } + let transport_event_tx_for_outbound = transport_event_tx.clone(); let outbound_handle = tokio::spawn(async move { let mut outbound_connections = HashMap::::new(); loop { @@ -374,7 +375,22 @@ pub async fn run_main_with_transport( let Some(envelope) = envelope else { break; }; - route_outgoing_envelope(&mut outbound_connections, envelope).await; + let disconnected_connections = + route_outgoing_envelope(&mut outbound_connections, envelope).await; + let mut should_exit = false; + for connection_id in disconnected_connections { + if transport_event_tx_for_outbound + .send(TransportEvent::ConnectionClosed { connection_id }) + .await + .is_err() + { + should_exit = true; + break; + } + } + if should_exit { + break; + } } event = outbound_control_rx.recv() => { let Some(event) = event else { diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index e4155cf4635..ae25720474c 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -422,7 +422,8 @@ fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option, envelope: OutgoingEnvelope, -) { +) -> Vec { + let mut disconnected = Vec::new(); match envelope { OutgoingEnvelope::ToConnection { connection_id, @@ -433,10 +434,11 @@ pub(crate) async fn route_outgoing_envelope( "dropping message for disconnected connection: {:?}", connection_id ); - return; + return disconnected; }; if connection_state.writer.send(message).await.is_err() { connections.remove(&connection_id); + disconnected.push(connection_id); } } OutgoingEnvelope::Broadcast { message } => { @@ -457,10 +459,12 @@ pub(crate) async fn route_outgoing_envelope( }; if connection_state.writer.send(message.clone()).await.is_err() { 
connections.remove(&connection_id); + disconnected.push(connection_id); } } } } + disconnected } pub(crate) fn has_initialized_connections( From 9a64991e6908b324376480718e8e037d2c27cefe Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 10 Feb 2026 14:38:45 +0000 Subject: [PATCH 04/16] Plan outbound routing fixes --- codex-rs/app-server/src/lib.rs | 45 ++++++-------------- codex-rs/app-server/src/message_processor.rs | 4 ++ codex-rs/app-server/src/transport.rs | 45 ++++++++++++-------- 3 files changed, 44 insertions(+), 50 deletions(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index fed9314f259..b20aaa53128 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -13,6 +13,7 @@ use std::io::ErrorKind; use std::io::Result as IoResult; use std::path::PathBuf; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use crate::message_processor::MessageProcessor; use crate::message_processor::MessageProcessorArgs; @@ -70,23 +71,18 @@ pub use crate::transport::AppServerTransport; /// - outbound loop: performs potentially slow writes to per-connection writers /// /// `OutboundControlEvent` keeps those loops coordinated without sharing mutable -/// connection state directly. In particular, the outbound loop needs to know: -/// - when a connection opens/closes so it can route messages correctly -/// - when a connection becomes initialized so broadcast semantics remain unchanged +/// connection state directly. In particular, the outbound loop needs to know +/// when a connection opens/closes so it can route messages correctly. enum OutboundControlEvent { /// Register a new writer for an opened connection. Opened { connection_id: ConnectionId, writer: mpsc::Sender, + initialized: Arc, ready: oneshot::Sender<()>, }, /// Remove state for a closed/disconnected connection. Closed { connection_id: ConnectionId }, - /// Mark the connection as initialized, enabling broadcast delivery. 
- Initialized { - connection_id: ConnectionId, - ready: oneshot::Sender<()>, - }, } fn config_warning_from_error( @@ -400,23 +396,18 @@ pub async fn run_main_with_transport( OutboundControlEvent::Opened { connection_id, writer, + initialized, ready, } => { - outbound_connections.insert(connection_id, OutboundConnectionState::new(writer)); + outbound_connections.insert( + connection_id, + OutboundConnectionState::new(writer, initialized), + ); let _ = ready.send(()); } OutboundControlEvent::Closed { connection_id } => { outbound_connections.remove(&connection_id); } - OutboundControlEvent::Initialized { - connection_id, - ready, - } => { - if let Some(connection_state) = outbound_connections.get_mut(&connection_id) { - connection_state.initialized = true; - } - let _ = ready.send(()); - } } } } @@ -451,11 +442,13 @@ pub async fn run_main_with_transport( }; match event { TransportEvent::ConnectionOpened { connection_id, writer } => { + let outbound_initialized = Arc::new(AtomicBool::new(false)); let (ready_tx, ready_rx) = oneshot::channel(); if outbound_control_tx .send(OutboundControlEvent::Opened { connection_id, writer: writer.clone(), + initialized: Arc::clone(&outbound_initialized), ready: ready_tx, }) .await @@ -466,7 +459,7 @@ pub async fn run_main_with_transport( if ready_rx.await.is_err() { break; } - connections.insert(connection_id, ConnectionState::new()); + connections.insert(connection_id, ConnectionState::new(outbound_initialized)); } TransportEvent::ConnectionClosed { connection_id } => { if outbound_control_tx @@ -494,22 +487,10 @@ pub async fn run_main_with_transport( connection_id, request, &mut connection_state.session, + &connection_state.outbound_initialized, ) .await; if !was_initialized && connection_state.session.initialized { - let (ready_tx, ready_rx) = oneshot::channel(); - let send_result = outbound_control_tx - .send(OutboundControlEvent::Initialized { - connection_id, - ready: ready_tx, - }) - .await; - if send_result.is_err() { - 
break; - } - if ready_rx.await.is_err() { - break; - } processor.send_initialize_notifications().await; } } diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index fdb981cff1d..b962d52902e 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -1,6 +1,8 @@ use std::path::PathBuf; use std::sync::Arc; use std::sync::RwLock; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; use crate::codex_message_processor::CodexMessageProcessor; use crate::codex_message_processor::CodexMessageProcessorArgs; @@ -191,6 +193,7 @@ impl MessageProcessor { connection_id: ConnectionId, request: JSONRPCRequest, session: &mut ConnectionSessionState, + outbound_initialized: &AtomicBool, ) { let request_id = ConnectionRequestId { connection_id, @@ -286,6 +289,7 @@ impl MessageProcessor { self.outgoing.send_response(request_id, response).await; session.initialized = true; + outbound_initialized.store(true, Ordering::Release); return; } } diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index ae25720474c..702c28ed695 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -17,6 +17,7 @@ use std::io::Result as IoResult; use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; use tokio::io::AsyncBufReadExt; @@ -143,27 +144,29 @@ pub(crate) enum TransportEvent { } pub(crate) struct ConnectionState { + pub(crate) outbound_initialized: Arc, pub(crate) session: ConnectionSessionState, } impl ConnectionState { - pub(crate) fn new() -> Self { + pub(crate) fn new(outbound_initialized: Arc) -> Self { Self { + outbound_initialized, session: ConnectionSessionState::default(), } } } pub(crate) struct OutboundConnectionState { + pub(crate) initialized: Arc, pub(crate) writer: 
mpsc::Sender, - pub(crate) initialized: bool, } impl OutboundConnectionState { - pub(crate) fn new(writer: mpsc::Sender) -> Self { + pub(crate) fn new(writer: mpsc::Sender, initialized: Arc) -> Self { Self { + initialized, writer, - initialized: false, } } } @@ -383,20 +386,26 @@ async fn enqueue_incoming_message( connection_id, message: JSONRPCMessage::Request(request), })) => { - if writer - .try_send(OutgoingMessage::Error(OutgoingError { - id: request.id, - error: JSONRPCErrorError { - code: OVERLOADED_ERROR_CODE, - message: "Server overloaded; retry later.".to_string(), - data: None, - }, - })) - .is_err() - { - warn!("failed to enqueue overload response for connection: {connection_id:?}"); + let overload_error = OutgoingMessage::Error(OutgoingError { + id: request.id, + error: JSONRPCErrorError { + code: OVERLOADED_ERROR_CODE, + message: "Server overloaded; retry later.".to_string(), + data: None, + }, + }); + match writer.try_send(overload_error) { + Ok(()) => true, + Err(mpsc::error::TrySendError::Closed(_)) => false, + Err(mpsc::error::TrySendError::Full(overload_error)) => { + if writer.send(overload_error).await.is_err() { + warn!("failed to send overload response for connection: {connection_id:?}"); + false + } else { + true + } + } } - true } Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(), } @@ -445,7 +454,7 @@ pub(crate) async fn route_outgoing_envelope( let target_connections: Vec = connections .iter() .filter_map(|(connection_id, connection_state)| { - if connection_state.initialized { + if connection_state.initialized.load(Ordering::Acquire) { Some(*connection_id) } else { None From 66f9f631a219e303489c1a455fbb1d4a0d49c58e Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 10 Feb 2026 14:46:51 +0000 Subject: [PATCH 05/16] Fix app-server outbound connection --- codex-rs/app-server/src/lib.rs | 52 ++++++++++------------ codex-rs/app-server/src/transport.rs | 64 +++++++++++++++++++++++++--- 2 files changed, 
79 insertions(+), 37 deletions(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index b20aaa53128..c09595428c4 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -39,7 +39,6 @@ use codex_core::config_loader::ConfigLoadError; use codex_core::config_loader::TextRange as CoreTextRange; use codex_feedback::CodexFeedback; use tokio::sync::mpsc; -use tokio::sync::oneshot; use tokio::task::JoinHandle; use toml::Value as TomlValue; use tracing::error; @@ -79,7 +78,6 @@ enum OutboundControlEvent { connection_id: ConnectionId, writer: mpsc::Sender, initialized: Arc, - ready: oneshot::Sender<()>, }, /// Remove state for a closed/disconnected connection. Closed { connection_id: ConnectionId }, @@ -367,6 +365,27 @@ pub async fn run_main_with_transport( let mut outbound_connections = HashMap::::new(); loop { tokio::select! { + biased; + event = outbound_control_rx.recv() => { + let Some(event) = event else { + break; + }; + match event { + OutboundControlEvent::Opened { + connection_id, + writer, + initialized, + } => { + outbound_connections.insert( + connection_id, + OutboundConnectionState::new(writer, initialized), + ); + } + OutboundControlEvent::Closed { connection_id } => { + outbound_connections.remove(&connection_id); + } + } + } envelope = outgoing_rx.recv() => { let Some(envelope) = envelope else { break; @@ -388,28 +407,6 @@ pub async fn run_main_with_transport( break; } } - event = outbound_control_rx.recv() => { - let Some(event) = event else { - break; - }; - match event { - OutboundControlEvent::Opened { - connection_id, - writer, - initialized, - ready, - } => { - outbound_connections.insert( - connection_id, - OutboundConnectionState::new(writer, initialized), - ); - let _ = ready.send(()); - } - OutboundControlEvent::Closed { connection_id } => { - outbound_connections.remove(&connection_id); - } - } - } } } info!("outbound router task exited (channel closed)"); @@ -443,22 +440,17 @@ pub async fn 
run_main_with_transport( match event { TransportEvent::ConnectionOpened { connection_id, writer } => { let outbound_initialized = Arc::new(AtomicBool::new(false)); - let (ready_tx, ready_rx) = oneshot::channel(); if outbound_control_tx .send(OutboundControlEvent::Opened { connection_id, - writer: writer.clone(), + writer, initialized: Arc::clone(&outbound_initialized), - ready: ready_tx, }) .await .is_err() { break; } - if ready_rx.await.is_err() { - break; - } connections.insert(connection_id, ConnectionState::new(outbound_initialized)); } TransportEvent::ConnectionClosed { connection_id } => { diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index 702c28ed695..d70eb3ffd8d 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -397,13 +397,12 @@ async fn enqueue_incoming_message( match writer.try_send(overload_error) { Ok(()) => true, Err(mpsc::error::TrySendError::Closed(_)) => false, - Err(mpsc::error::TrySendError::Full(overload_error)) => { - if writer.send(overload_error).await.is_err() { - warn!("failed to send overload response for connection: {connection_id:?}"); - false - } else { - true - } + Err(mpsc::error::TrySendError::Full(_overload_error)) => { + warn!( + "dropping overload response for connection {:?}: outbound queue is full", + connection_id + ); + true } } } @@ -660,4 +659,55 @@ mod tests { _ => panic!("expected forwarded response message"), } } + + #[tokio::test] + async fn enqueue_incoming_request_does_not_block_when_writer_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, _transport_event_rx) = mpsc::channel(1); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: JSONRPCMessage::Notification( + codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }, + ), + }) + .await + 
.expect("transport queue should accept first message"); + + writer_tx + .send(OutgoingMessage::Notification( + crate::outgoing_message::OutgoingNotification { + method: "queued".to_string(), + params: None, + }, + )) + .await + .expect("writer queue should accept first message"); + + let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(7), + method: "config/read".to_string(), + params: Some(json!({ "includeLayers": false })), + }); + + let enqueue_result = tokio::time::timeout( + std::time::Duration::from_millis(100), + enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request), + ) + .await + .expect("enqueue should not block while writer queue is full"); + assert!(enqueue_result); + + let queued_outgoing = writer_rx + .recv() + .await + .expect("writer queue should still contain original message"); + let queued_json = serde_json::to_value(queued_outgoing).expect("serialize queued message"); + assert_eq!(queued_json, json!({ "method": "queued" })); + } } From 5cba950f2514059b869b3504a8e8d6bb6aa17ee3 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Tue, 10 Feb 2026 15:00:35 +0000 Subject: [PATCH 06/16] Review app-server request loop --- codex-rs/app-server/src/lib.rs | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index c09595428c4..0b23fed74ed 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -9,6 +9,7 @@ use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::ConfigLayerStackOrdering; use codex_core::config_loader::LoaderOverrides; use std::collections::HashMap; +use std::collections::VecDeque; use std::io::ErrorKind; use std::io::Result as IoResult; use std::path::PathBuf; @@ -363,6 +364,7 @@ pub async fn run_main_with_transport( let transport_event_tx_for_outbound = 
transport_event_tx.clone(); let outbound_handle = tokio::spawn(async move { let mut outbound_connections = HashMap::::new(); + let mut pending_closed_connections = VecDeque::::new(); loop { tokio::select! { biased; @@ -392,20 +394,23 @@ pub async fn run_main_with_transport( }; let disconnected_connections = route_outgoing_envelope(&mut outbound_connections, envelope).await; - let mut should_exit = false; - for connection_id in disconnected_connections { - if transport_event_tx_for_outbound - .send(TransportEvent::ConnectionClosed { connection_id }) - .await - .is_err() - { - should_exit = true; - break; - } + pending_closed_connections.extend(disconnected_connections); + } + } + + while let Some(connection_id) = pending_closed_connections.front().copied() { + match transport_event_tx_for_outbound + .try_send(TransportEvent::ConnectionClosed { connection_id }) + { + Ok(()) => { + pending_closed_connections.pop_front(); } - if should_exit { + Err(mpsc::error::TrySendError::Full(_)) => { break; } + Err(mpsc::error::TrySendError::Closed(_)) => { + return; + } } } } From 7f6daba35d649c8f75be6dfc5ad263bec8d34079 Mon Sep 17 00:00:00 2001 From: Max Johnson Date: Tue, 10 Feb 2026 12:56:11 -0800 Subject: [PATCH 07/16] Reapply "Add app-server transport layer with websocket support (#10693)" (#11323) This reverts commit 47356ff83c38305e08c4a075fde4624cb04a7aaf. 
--- codex-rs/Cargo.lock | 5 + codex-rs/app-server/Cargo.toml | 8 +- codex-rs/app-server/README.md | 13 +- .../app-server/src/bespoke_event_handling.rs | 70 ++- .../app-server/src/codex_message_processor.rs | 420 ++++++++++------ codex-rs/app-server/src/lib.rs | 174 ++++--- codex-rs/app-server/src/main.rs | 21 +- codex-rs/app-server/src/message_processor.rs | 120 +++-- codex-rs/app-server/src/outgoing_message.rs | 168 ++++++- codex-rs/app-server/src/transport.rs | 459 ++++++++++++++++++ .../app-server/tests/common/mcp_process.rs | 2 +- .../suite/v2/connection_handling_websocket.rs | 263 ++++++++++ codex-rs/app-server/tests/suite/v2/mod.rs | 1 + codex-rs/app-server/tests/suite/v2/review.rs | 6 +- codex-rs/cli/src/main.rs | 47 +- codex-rs/core/tests/suite/review.rs | 19 - 16 files changed, 1455 insertions(+), 341 deletions(-) create mode 100644 codex-rs/app-server/src/transport.rs create mode 100644 codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 92716cc41bc..1973f8923ef 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1348,6 +1348,7 @@ dependencies = [ "axum", "base64 0.22.1", "chrono", + "clap", "codex-app-server-protocol", "codex-arg0", "codex-backend-client", @@ -1362,9 +1363,12 @@ dependencies = [ "codex-protocol", "codex-rmcp-client", "codex-utils-absolute-path", + "codex-utils-cargo-bin", "codex-utils-json-to-toml", "core_test_support", + "futures", "os_info", + "owo-colors", "pretty_assertions", "rmcp", "serde", @@ -1374,6 +1378,7 @@ dependencies = [ "tempfile", "time", "tokio", + "tokio-tungstenite", "toml 0.9.11+spec-1.1.0", "tracing", "tracing-subscriber", diff --git a/codex-rs/app-server/Cargo.toml b/codex-rs/app-server/Cargo.toml index a4e848213aa..f68e0787759 100644 --- a/codex-rs/app-server/Cargo.toml +++ b/codex-rs/app-server/Cargo.toml @@ -30,8 +30,12 @@ codex-protocol = { workspace = true } codex-app-server-protocol = { workspace = true } codex-feedback = { 
workspace = true } codex-rmcp-client = { workspace = true } +codex-utils-absolute-path = { workspace = true } codex-utils-json-to-toml = { workspace = true } chrono = { workspace = true } +clap = { workspace = true, features = ["derive"] } +futures = { workspace = true } +owo-colors = { workspace = true, features = ["supports-colors"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tempfile = { workspace = true } @@ -44,6 +48,7 @@ tokio = { workspace = true, features = [ "rt-multi-thread", "signal", ] } +tokio-tungstenite = { workspace = true } tracing = { workspace = true, features = ["log"] } tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } uuid = { workspace = true, features = ["serde", "v7"] } @@ -57,8 +62,8 @@ axum = { workspace = true, default-features = false, features = [ ] } base64 = { workspace = true } codex-execpolicy = { workspace = true } -codex-utils-absolute-path = { workspace = true } core_test_support = { workspace = true } +codex-utils-cargo-bin = { workspace = true } os_info = { workspace = true } pretty_assertions = { workspace = true } rmcp = { workspace = true, default-features = false, features = [ @@ -66,5 +71,6 @@ rmcp = { workspace = true, default-features = false, features = [ "transport-streamable-http-server", ] } serial_test = { workspace = true } +tokio-tungstenite = { workspace = true } wiremock = { workspace = true } shlex = { workspace = true } diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 7608262077c..90b37c9b66d 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -19,7 +19,14 @@ ## Protocol -Similar to [MCP](https://modelcontextprotocol.io/), `codex app-server` supports bidirectional communication, streaming JSONL over stdio. The protocol is JSON-RPC 2.0, though the `"jsonrpc":"2.0"` header is omitted. 
+Similar to [MCP](https://modelcontextprotocol.io/), `codex app-server` supports bidirectional communication using JSON-RPC 2.0 messages (with the `"jsonrpc":"2.0"` header omitted on the wire). + +Supported transports: + +- stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL) +- websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**) + +Websocket transport is currently experimental and unsupported. Do not rely on it for production workloads. ## Message Schema @@ -42,7 +49,7 @@ Use the thread APIs to create, list, or archive conversations. Drive a conversat ## Lifecycle Overview -- Initialize once: Immediately after launching the codex app-server process, send an `initialize` request with your client metadata, then emit an `initialized` notification. Any other request before this handshake gets rejected. +- Initialize once per connection: Immediately after opening a transport connection, send an `initialize` request with your client metadata, then emit an `initialized` notification. Any other request on that connection before this handshake gets rejected. - Start (or resume) a thread: Call `thread/start` to open a fresh conversation. The response returns the thread object and you’ll also get a `thread/started` notification. If you’re continuing an existing conversation, call `thread/resume` with its ID instead. If you want to branch from an existing conversation, call `thread/fork` to create a new thread id with copied history. - Begin a turn: To send user input, call `turn/start` with the target `threadId` and the user's input. Optional fields let you override model, cwd, sandbox policy, etc. This immediately returns the new turn object and triggers a `turn/started` notification. - Stream events: After `turn/start`, keep reading JSON-RPC notifications on stdout. You’ll see `item/started`, `item/completed`, deltas like `item/agentMessage/delta`, tool progress, etc. 
These represent streaming model output plus any side effects (commands, tool calls, reasoning notes). @@ -50,7 +57,7 @@ Use the thread APIs to create, list, or archive conversations. Drive a conversat ## Initialization -Clients must send a single `initialize` request before invoking any other method, then acknowledge with an `initialized` notification. The server returns the user agent string it will present to upstream services; subsequent requests issued before initialization receive a `"Not initialized"` error, and repeated `initialize` calls receive an `"Already initialized"` error. +Clients must send a single `initialize` request per transport connection before invoking any other method on that connection, then acknowledge with an `initialized` notification. The server returns the user agent string it will present to upstream services; subsequent requests issued before initialization receive a `"Not initialized"` error, and repeated `initialize` calls on the same connection receive an `"Already initialized"` error. `initialize.params.capabilities` also supports per-connection notification opt-out via `optOutNotificationMethods`, which is a list of exact method names to suppress for that connection. Matching is exact (no wildcards/prefixes). Unknown method names are accepted and ignored. 
diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index 73647b1c08d..a9c330f1b63 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -1115,7 +1115,7 @@ pub(crate) async fn apply_bespoke_event_handling( ), data: None, }; - outgoing.send_error(request_id, error).await; + outgoing.send_error(request_id.clone(), error).await; return; } } @@ -1129,7 +1129,7 @@ pub(crate) async fn apply_bespoke_event_handling( ), data: None, }; - outgoing.send_error(request_id, error).await; + outgoing.send_error(request_id.clone(), error).await; return; } }; @@ -1894,6 +1894,7 @@ async fn construct_mcp_tool_call_end_notification( mod tests { use super::*; use crate::CHANNEL_CAPACITY; + use crate::outgoing_message::OutgoingEnvelope; use crate::outgoing_message::OutgoingMessage; use crate::outgoing_message::OutgoingMessageSender; use anyhow::Result; @@ -1923,6 +1924,21 @@ mod tests { Arc::new(Mutex::new(HashMap::new())) } + async fn recv_broadcast_message( + rx: &mut mpsc::Receiver, + ) -> Result { + let envelope = rx + .recv() + .await + .ok_or_else(|| anyhow!("should send one message"))?; + match envelope { + OutgoingEnvelope::Broadcast { message } => Ok(message), + OutgoingEnvelope::ToConnection { connection_id, .. 
} => { + bail!("unexpected targeted message for connection {connection_id:?}") + } + } + } + #[test] fn file_change_accept_for_session_maps_to_approved_for_session() { let (decision, completion_status) = @@ -2024,10 +2040,7 @@ mod tests { ) .await; - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send one notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, event_turn_id); @@ -2066,10 +2079,7 @@ mod tests { ) .await; - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send one notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, event_turn_id); @@ -2108,10 +2118,7 @@ mod tests { ) .await; - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send one notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, event_turn_id); @@ -2160,10 +2167,7 @@ mod tests { ) .await; - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send one notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnPlanUpdated(n)) => { assert_eq!(n.thread_id, conversation_id.to_string()); @@ -2231,10 +2235,7 @@ mod tests { ) .await; - let first = rx - .recv() - .await - .ok_or_else(|| anyhow!("expected usage notification"))?; + let first = recv_broadcast_message(&mut rx).await?; match first { OutgoingMessage::AppServerNotification( ServerNotification::ThreadTokenUsageUpdated(payload), @@ -2250,10 +2251,7 @@ mod tests { other => bail!("unexpected notification: {other:?}"), } - let second = rx - .recv() - .await - .ok_or_else(|| anyhow!("expected rate limit 
notification"))?; + let second = recv_broadcast_message(&mut rx).await?; match second { OutgoingMessage::AppServerNotification( ServerNotification::AccountRateLimitsUpdated(payload), @@ -2390,10 +2388,7 @@ mod tests { .await; // Verify: A turn 1 - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send first notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, a_turn1); @@ -2411,10 +2406,7 @@ mod tests { } // Verify: B turn 1 - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send second notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, b_turn1); @@ -2432,10 +2424,7 @@ mod tests { } // Verify: A turn 2 - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send third notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, a_turn2); @@ -2601,10 +2590,7 @@ mod tests { ) .await; - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send one notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnDiffUpdated( notification, diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index e8b921980f4..6e194b4512e 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -3,6 +3,8 @@ use crate::error_code::INTERNAL_ERROR_CODE; use crate::error_code::INVALID_REQUEST_ERROR_CODE; use crate::fuzzy_file_search::run_fuzzy_file_search; use crate::models::supported_models; +use crate::outgoing_message::ConnectionId; +use 
crate::outgoing_message::ConnectionRequestId; use crate::outgoing_message::OutgoingMessageSender; use crate::outgoing_message::OutgoingNotification; use chrono::DateTime; @@ -83,7 +85,6 @@ use codex_app_server_protocol::NewConversationParams; use codex_app_server_protocol::NewConversationResponse; use codex_app_server_protocol::RemoveConversationListenerParams; use codex_app_server_protocol::RemoveConversationSubscriptionResponse; -use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ResumeConversationParams; use codex_app_server_protocol::ResumeConversationResponse; use codex_app_server_protocol::ReviewDelivery as ApiReviewDelivery; @@ -252,10 +253,10 @@ use uuid::Uuid; use crate::filters::compute_source_filters; use crate::filters::source_kind_matches; -type PendingInterruptQueue = Vec<(RequestId, ApiVersion)>; +type PendingInterruptQueue = Vec<(ConnectionRequestId, ApiVersion)>; pub(crate) type PendingInterrupts = Arc>>; -pub(crate) type PendingRollbacks = Arc>>; +pub(crate) type PendingRollbacks = Arc>>; /// Per-conversation accumulation of the latest states e.g. error message while a turn runs. #[derive(Default, Clone)] @@ -486,103 +487,137 @@ impl CodexMessageProcessor { Ok((review_request, hint)) } - pub async fn process_request(&mut self, request: ClientRequest) { + pub async fn process_request(&mut self, connection_id: ConnectionId, request: ClientRequest) { + let to_connection_request_id = |request_id| ConnectionRequestId { + connection_id, + request_id, + }; + match request { ClientRequest::Initialize { .. 
} => { panic!("Initialize should be handled in MessageProcessor"); } // === v2 Thread/Turn APIs === ClientRequest::ThreadStart { request_id, params } => { - self.thread_start(request_id, params).await; + self.thread_start(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadResume { request_id, params } => { - self.thread_resume(request_id, params).await; + self.thread_resume(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadFork { request_id, params } => { - self.thread_fork(request_id, params).await; + self.thread_fork(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadArchive { request_id, params } => { - self.thread_archive(request_id, params).await; + self.thread_archive(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadSetName { request_id, params } => { - self.thread_set_name(request_id, params).await; + self.thread_set_name(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadUnarchive { request_id, params } => { - self.thread_unarchive(request_id, params).await; + self.thread_unarchive(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadCompactStart { request_id, params } => { - self.thread_compact_start(request_id, params).await; + self.thread_compact_start(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadBackgroundTerminalsClean { request_id, params } => { - self.thread_background_terminals_clean(request_id, params) - .await; + self.thread_background_terminals_clean( + to_connection_request_id(request_id), + params, + ) + .await; } ClientRequest::ThreadRollback { request_id, params } => { - self.thread_rollback(request_id, params).await; + self.thread_rollback(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadList { request_id, params } => { - self.thread_list(request_id, params).await; + 
self.thread_list(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadLoadedList { request_id, params } => { - self.thread_loaded_list(request_id, params).await; + self.thread_loaded_list(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadRead { request_id, params } => { - self.thread_read(request_id, params).await; + self.thread_read(to_connection_request_id(request_id), params) + .await; } ClientRequest::SkillsList { request_id, params } => { - self.skills_list(request_id, params).await; + self.skills_list(to_connection_request_id(request_id), params) + .await; } ClientRequest::SkillsRemoteRead { request_id, params } => { - self.skills_remote_read(request_id, params).await; + self.skills_remote_read(to_connection_request_id(request_id), params) + .await; } ClientRequest::SkillsRemoteWrite { request_id, params } => { - self.skills_remote_write(request_id, params).await; + self.skills_remote_write(to_connection_request_id(request_id), params) + .await; } ClientRequest::AppsList { request_id, params } => { - self.apps_list(request_id, params).await; + self.apps_list(to_connection_request_id(request_id), params) + .await; } ClientRequest::SkillsConfigWrite { request_id, params } => { - self.skills_config_write(request_id, params).await; + self.skills_config_write(to_connection_request_id(request_id), params) + .await; } ClientRequest::TurnStart { request_id, params } => { - self.turn_start(request_id, params).await; + self.turn_start(to_connection_request_id(request_id), params) + .await; } ClientRequest::TurnSteer { request_id, params } => { - self.turn_steer(request_id, params).await; + self.turn_steer(to_connection_request_id(request_id), params) + .await; } ClientRequest::TurnInterrupt { request_id, params } => { - self.turn_interrupt(request_id, params).await; + self.turn_interrupt(to_connection_request_id(request_id), params) + .await; } ClientRequest::ReviewStart { request_id, params } => { - 
self.review_start(request_id, params).await; + self.review_start(to_connection_request_id(request_id), params) + .await; } ClientRequest::NewConversation { request_id, params } => { // Do not tokio::spawn() to process new_conversation() // asynchronously because we need to ensure the conversation is // created before processing any subsequent messages. - self.process_new_conversation(request_id, params).await; + self.process_new_conversation(to_connection_request_id(request_id), params) + .await; } ClientRequest::GetConversationSummary { request_id, params } => { - self.get_thread_summary(request_id, params).await; + self.get_thread_summary(to_connection_request_id(request_id), params) + .await; } ClientRequest::ListConversations { request_id, params } => { - self.handle_list_conversations(request_id, params).await; + self.handle_list_conversations(to_connection_request_id(request_id), params) + .await; } ClientRequest::ModelList { request_id, params } => { let outgoing = self.outgoing.clone(); let thread_manager = self.thread_manager.clone(); let config = self.config.clone(); + let request_id = to_connection_request_id(request_id); tokio::spawn(async move { Self::list_models(outgoing, thread_manager, config, request_id, params).await; }); } ClientRequest::ExperimentalFeatureList { request_id, params } => { - self.experimental_feature_list(request_id, params).await; + self.experimental_feature_list(to_connection_request_id(request_id), params) + .await; } ClientRequest::CollaborationModeList { request_id, params } => { let outgoing = self.outgoing.clone(); let thread_manager = self.thread_manager.clone(); + let request_id = to_connection_request_id(request_id); tokio::spawn(async move { Self::list_collaboration_modes(outgoing, thread_manager, request_id, params) @@ -590,109 +625,136 @@ impl CodexMessageProcessor { }); } ClientRequest::MockExperimentalMethod { request_id, params } => { - self.mock_experimental_method(request_id, params).await; + 
self.mock_experimental_method(to_connection_request_id(request_id), params) + .await; } ClientRequest::McpServerOauthLogin { request_id, params } => { - self.mcp_server_oauth_login(request_id, params).await; + self.mcp_server_oauth_login(to_connection_request_id(request_id), params) + .await; } ClientRequest::McpServerRefresh { request_id, params } => { - self.mcp_server_refresh(request_id, params).await; + self.mcp_server_refresh(to_connection_request_id(request_id), params) + .await; } ClientRequest::McpServerStatusList { request_id, params } => { - self.list_mcp_server_status(request_id, params).await; + self.list_mcp_server_status(to_connection_request_id(request_id), params) + .await; } ClientRequest::LoginAccount { request_id, params } => { - self.login_v2(request_id, params).await; + self.login_v2(to_connection_request_id(request_id), params) + .await; } ClientRequest::LogoutAccount { request_id, params: _, } => { - self.logout_v2(request_id).await; + self.logout_v2(to_connection_request_id(request_id)).await; } ClientRequest::CancelLoginAccount { request_id, params } => { - self.cancel_login_v2(request_id, params).await; + self.cancel_login_v2(to_connection_request_id(request_id), params) + .await; } ClientRequest::GetAccount { request_id, params } => { - self.get_account(request_id, params).await; + self.get_account(to_connection_request_id(request_id), params) + .await; } ClientRequest::ResumeConversation { request_id, params } => { - self.handle_resume_conversation(request_id, params).await; + self.handle_resume_conversation(to_connection_request_id(request_id), params) + .await; } ClientRequest::ForkConversation { request_id, params } => { - self.handle_fork_conversation(request_id, params).await; + self.handle_fork_conversation(to_connection_request_id(request_id), params) + .await; } ClientRequest::ArchiveConversation { request_id, params } => { - self.archive_conversation(request_id, params).await; + 
self.archive_conversation(to_connection_request_id(request_id), params) + .await; } ClientRequest::SendUserMessage { request_id, params } => { - self.send_user_message(request_id, params).await; + self.send_user_message(to_connection_request_id(request_id), params) + .await; } ClientRequest::SendUserTurn { request_id, params } => { - self.send_user_turn(request_id, params).await; + self.send_user_turn(to_connection_request_id(request_id), params) + .await; } ClientRequest::InterruptConversation { request_id, params } => { - self.interrupt_conversation(request_id, params).await; + self.interrupt_conversation(to_connection_request_id(request_id), params) + .await; } ClientRequest::AddConversationListener { request_id, params } => { - self.add_conversation_listener(request_id, params).await; + self.add_conversation_listener(to_connection_request_id(request_id), params) + .await; } ClientRequest::RemoveConversationListener { request_id, params } => { - self.remove_thread_listener(request_id, params).await; + self.remove_thread_listener(to_connection_request_id(request_id), params) + .await; } ClientRequest::GitDiffToRemote { request_id, params } => { - self.git_diff_to_origin(request_id, params.cwd).await; + self.git_diff_to_origin(to_connection_request_id(request_id), params.cwd) + .await; } ClientRequest::LoginApiKey { request_id, params } => { - self.login_api_key_v1(request_id, params).await; + self.login_api_key_v1(to_connection_request_id(request_id), params) + .await; } ClientRequest::LoginChatGpt { request_id, params: _, } => { - self.login_chatgpt_v1(request_id).await; + self.login_chatgpt_v1(to_connection_request_id(request_id)) + .await; } ClientRequest::CancelLoginChatGpt { request_id, params } => { - self.cancel_login_chatgpt(request_id, params.login_id).await; + self.cancel_login_chatgpt(to_connection_request_id(request_id), params.login_id) + .await; } ClientRequest::LogoutChatGpt { request_id, params: _, } => { - self.logout_v1(request_id).await; + 
self.logout_v1(to_connection_request_id(request_id)).await; } ClientRequest::GetAuthStatus { request_id, params } => { - self.get_auth_status(request_id, params).await; + self.get_auth_status(to_connection_request_id(request_id), params) + .await; } ClientRequest::GetUserSavedConfig { request_id, params: _, } => { - self.get_user_saved_config(request_id).await; + self.get_user_saved_config(to_connection_request_id(request_id)) + .await; } ClientRequest::SetDefaultModel { request_id, params } => { - self.set_default_model(request_id, params).await; + self.set_default_model(to_connection_request_id(request_id), params) + .await; } ClientRequest::GetUserAgent { request_id, params: _, } => { - self.get_user_agent(request_id).await; + self.get_user_agent(to_connection_request_id(request_id)) + .await; } ClientRequest::UserInfo { request_id, params: _, } => { - self.get_user_info(request_id).await; + self.get_user_info(to_connection_request_id(request_id)) + .await; } ClientRequest::FuzzyFileSearch { request_id, params } => { - self.fuzzy_file_search(request_id, params).await; + self.fuzzy_file_search(to_connection_request_id(request_id), params) + .await; } ClientRequest::OneOffCommandExec { request_id, params } => { - self.exec_one_off_command(request_id, params).await; + self.exec_one_off_command(to_connection_request_id(request_id), params) + .await; } ClientRequest::ExecOneOffCommand { request_id, params } => { - self.exec_one_off_command(request_id, params.into()).await; + self.exec_one_off_command(to_connection_request_id(request_id), params.into()) + .await; } ClientRequest::ConfigRead { .. } | ClientRequest::ConfigValueWrite { .. 
} @@ -706,15 +768,17 @@ impl CodexMessageProcessor { request_id, params: _, } => { - self.get_account_rate_limits(request_id).await; + self.get_account_rate_limits(to_connection_request_id(request_id)) + .await; } ClientRequest::FeedbackUpload { request_id, params } => { - self.upload_feedback(request_id, params).await; + self.upload_feedback(to_connection_request_id(request_id), params) + .await; } } } - async fn login_v2(&mut self, request_id: RequestId, params: LoginAccountParams) { + async fn login_v2(&mut self, request_id: ConnectionRequestId, params: LoginAccountParams) { match params { LoginAccountParams::ApiKey { api_key } => { self.login_api_key_v2(request_id, LoginApiKeyParams { api_key }) @@ -792,7 +856,11 @@ impl CodexMessageProcessor { } } - async fn login_api_key_v1(&mut self, request_id: RequestId, params: LoginApiKeyParams) { + async fn login_api_key_v1( + &mut self, + request_id: ConnectionRequestId, + params: LoginApiKeyParams, + ) { match self.login_api_key_common(¶ms).await { Ok(()) => { self.outgoing @@ -816,7 +884,11 @@ impl CodexMessageProcessor { } } - async fn login_api_key_v2(&mut self, request_id: RequestId, params: LoginApiKeyParams) { + async fn login_api_key_v2( + &mut self, + request_id: ConnectionRequestId, + params: LoginApiKeyParams, + ) { match self.login_api_key_common(¶ms).await { Ok(()) => { let response = codex_app_server_protocol::LoginAccountResponse::ApiKey {}; @@ -880,7 +952,7 @@ impl CodexMessageProcessor { } // Deprecated in favor of login_chatgpt_v2. 
- async fn login_chatgpt_v1(&mut self, request_id: RequestId) { + async fn login_chatgpt_v1(&mut self, request_id: ConnectionRequestId) { match self.login_chatgpt_common().await { Ok(opts) => match run_login_server(opts) { Ok(server) => { @@ -986,7 +1058,7 @@ impl CodexMessageProcessor { } } - async fn login_chatgpt_v2(&mut self, request_id: RequestId) { + async fn login_chatgpt_v2(&mut self, request_id: ConnectionRequestId) { match self.login_chatgpt_common().await { Ok(opts) => match run_login_server(opts) { Ok(server) => { @@ -1110,7 +1182,7 @@ impl CodexMessageProcessor { } } - async fn cancel_login_chatgpt(&mut self, request_id: RequestId, login_id: Uuid) { + async fn cancel_login_chatgpt(&mut self, request_id: ConnectionRequestId, login_id: Uuid) { match self.cancel_login_chatgpt_common(login_id).await { Ok(()) => { self.outgoing @@ -1128,7 +1200,11 @@ impl CodexMessageProcessor { } } - async fn cancel_login_v2(&mut self, request_id: RequestId, params: CancelLoginAccountParams) { + async fn cancel_login_v2( + &mut self, + request_id: ConnectionRequestId, + params: CancelLoginAccountParams, + ) { let login_id = params.login_id; match Uuid::parse_str(&login_id) { Ok(uuid) => { @@ -1152,7 +1228,7 @@ impl CodexMessageProcessor { async fn login_chatgpt_auth_tokens( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, access_token: String, chatgpt_account_id: String, chatgpt_plan_type: Option, @@ -1267,7 +1343,7 @@ impl CodexMessageProcessor { .map(CodexAuth::api_auth_mode)) } - async fn logout_v1(&mut self, request_id: RequestId) { + async fn logout_v1(&mut self, request_id: ConnectionRequestId) { match self.logout_common().await { Ok(current_auth_method) => { self.outgoing @@ -1287,7 +1363,7 @@ impl CodexMessageProcessor { } } - async fn logout_v2(&mut self, request_id: RequestId) { + async fn logout_v2(&mut self, request_id: ConnectionRequestId) { match self.logout_common().await { Ok(current_auth_method) => { self.outgoing @@ -1316,7 +1392,7 
@@ impl CodexMessageProcessor { } } - async fn get_auth_status(&self, request_id: RequestId, params: GetAuthStatusParams) { + async fn get_auth_status(&self, request_id: ConnectionRequestId, params: GetAuthStatusParams) { let include_token = params.include_token.unwrap_or(false); let do_refresh = params.refresh_token.unwrap_or(false); @@ -1365,7 +1441,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn get_account(&self, request_id: RequestId, params: GetAccountParams) { + async fn get_account(&self, request_id: ConnectionRequestId, params: GetAccountParams) { let do_refresh = params.refresh_token; self.refresh_token_if_requested(do_refresh).await; @@ -1417,13 +1493,13 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn get_user_agent(&self, request_id: RequestId) { + async fn get_user_agent(&self, request_id: ConnectionRequestId) { let user_agent = get_codex_user_agent(); let response = GetUserAgentResponse { user_agent }; self.outgoing.send_response(request_id, response).await; } - async fn get_account_rate_limits(&self, request_id: RequestId) { + async fn get_account_rate_limits(&self, request_id: ConnectionRequestId) { match self.fetch_account_rate_limits().await { Ok(rate_limits) => { let response = GetAccountRateLimitsResponse { @@ -1471,7 +1547,7 @@ impl CodexMessageProcessor { }) } - async fn get_user_saved_config(&self, request_id: RequestId) { + async fn get_user_saved_config(&self, request_id: ConnectionRequestId) { let service = ConfigService::new_with_defaults(self.config.codex_home.clone()); let user_saved_config: UserSavedConfig = match service.load_user_saved_config().await { Ok(config) => config, @@ -1492,7 +1568,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn get_user_info(&self, request_id: RequestId) { + async fn get_user_info(&self, request_id: ConnectionRequestId) { // Read alleged user 
email from cached auth (best-effort; not verified). let alleged_user_email = self .auth_manager @@ -1503,7 +1579,11 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn set_default_model(&self, request_id: RequestId, params: SetDefaultModelParams) { + async fn set_default_model( + &self, + request_id: ConnectionRequestId, + params: SetDefaultModelParams, + ) { let SetDefaultModelParams { model, reasoning_effort, @@ -1530,16 +1610,22 @@ impl CodexMessageProcessor { } } - async fn exec_one_off_command(&self, request_id: RequestId, params: CommandExecParams) { + async fn exec_one_off_command( + &self, + request_id: ConnectionRequestId, + params: CommandExecParams, + ) { tracing::debug!("ExecOneOffCommand params: {params:?}"); + let request = request_id.clone(); + if params.command.is_empty() { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, message: "command must not be empty".to_string(), data: None, }; - self.outgoing.send_error(request_id, error).await; + self.outgoing.send_error(request, error).await; return; } @@ -1557,7 +1643,7 @@ impl CodexMessageProcessor { message: format!("failed to start managed network proxy: {err}"), data: None, }; - self.outgoing.send_error(request_id, error).await; + self.outgoing.send_error(request, error).await; return; } }, @@ -1588,7 +1674,7 @@ impl CodexMessageProcessor { message: format!("invalid sandbox policy: {err}"), data: None, }; - self.outgoing.send_error(request_id, error).await; + self.outgoing.send_error(request, error).await; return; } }, @@ -1597,7 +1683,7 @@ impl CodexMessageProcessor { let codex_linux_sandbox_exe = self.config.codex_linux_sandbox_exe.clone(); let outgoing = self.outgoing.clone(); - let req_id = request_id; + let request_for_task = request; let sandbox_cwd = self.config.cwd.clone(); let started_network_proxy_for_task = started_network_proxy; let use_linux_sandbox_bwrap = self.config.features.enabled(Feature::UseLinuxSandboxBwrap); @@ 
-1620,7 +1706,7 @@ impl CodexMessageProcessor { stdout: output.stdout.text, stderr: output.stderr.text, }; - outgoing.send_response(req_id, response).await; + outgoing.send_response(request_for_task, response).await; } Err(err) => { let error = JSONRPCErrorError { @@ -1628,7 +1714,7 @@ impl CodexMessageProcessor { message: format!("exec failed: {err}"), data: None, }; - outgoing.send_error(req_id, error).await; + outgoing.send_error(request_for_task, error).await; } } }); @@ -1636,7 +1722,7 @@ impl CodexMessageProcessor { async fn process_new_conversation( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, params: NewConversationParams, ) { let NewConversationParams { @@ -1737,7 +1823,7 @@ impl CodexMessageProcessor { } } - async fn thread_start(&mut self, request_id: RequestId, params: ThreadStartParams) { + async fn thread_start(&mut self, request_id: ConnectionRequestId, params: ThreadStartParams) { let ThreadStartParams { model, model_provider, @@ -1902,7 +1988,11 @@ impl CodexMessageProcessor { } } - async fn thread_archive(&mut self, request_id: RequestId, params: ThreadArchiveParams) { + async fn thread_archive( + &mut self, + request_id: ConnectionRequestId, + params: ThreadArchiveParams, + ) { // TODO(jif) mostly rewrite this using sqlite after phase 1 let thread_id = match ThreadId::from_string(¶ms.thread_id) { Ok(id) => id, @@ -1952,7 +2042,7 @@ impl CodexMessageProcessor { } } - async fn thread_set_name(&self, request_id: RequestId, params: ThreadSetNameParams) { + async fn thread_set_name(&self, request_id: ConnectionRequestId, params: ThreadSetNameParams) { let ThreadSetNameParams { thread_id, name } = params; let Some(name) = codex_core::util::normalize_thread_name(&name) else { self.send_invalid_request_error( @@ -1982,7 +2072,11 @@ impl CodexMessageProcessor { .await; } - async fn thread_unarchive(&mut self, request_id: RequestId, params: ThreadUnarchiveParams) { + async fn thread_unarchive( + &mut self, + request_id: 
ConnectionRequestId, + params: ThreadUnarchiveParams, + ) { // TODO(jif) mostly rewrite this using sqlite after phase 1 let thread_id = match ThreadId::from_string(¶ms.thread_id) { Ok(id) => id, @@ -2155,7 +2249,11 @@ impl CodexMessageProcessor { } } - async fn thread_rollback(&mut self, request_id: RequestId, params: ThreadRollbackParams) { + async fn thread_rollback( + &mut self, + request_id: ConnectionRequestId, + params: ThreadRollbackParams, + ) { let ThreadRollbackParams { thread_id, num_turns, @@ -2175,18 +2273,20 @@ impl CodexMessageProcessor { } }; + let request = request_id.clone(); + { let mut map = self.pending_rollbacks.lock().await; if map.contains_key(&thread_id) { self.send_invalid_request_error( - request_id, + request.clone(), "rollback already in progress for this thread".to_string(), ) .await; return; } - map.insert(thread_id, request_id.clone()); + map.insert(thread_id, request.clone()); } if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await { @@ -2195,12 +2295,16 @@ impl CodexMessageProcessor { let mut map = self.pending_rollbacks.lock().await; map.remove(&thread_id); - self.send_internal_error(request_id, format!("failed to start rollback: {err}")) + self.send_internal_error(request, format!("failed to start rollback: {err}")) .await; } } - async fn thread_compact_start(&self, request_id: RequestId, params: ThreadCompactStartParams) { + async fn thread_compact_start( + &self, + request_id: ConnectionRequestId, + params: ThreadCompactStartParams, + ) { let ThreadCompactStartParams { thread_id } = params; let (_, thread) = match self.load_thread(&thread_id).await { @@ -2226,7 +2330,7 @@ impl CodexMessageProcessor { async fn thread_background_terminals_clean( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ThreadBackgroundTerminalsCleanParams, ) { let ThreadBackgroundTerminalsCleanParams { thread_id } = params; @@ -2255,7 +2359,7 @@ impl CodexMessageProcessor { } } - async fn thread_list(&self, 
request_id: RequestId, params: ThreadListParams) { + async fn thread_list(&self, request_id: ConnectionRequestId, params: ThreadListParams) { let ThreadListParams { cursor, limit, @@ -2296,7 +2400,11 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn thread_loaded_list(&self, request_id: RequestId, params: ThreadLoadedListParams) { + async fn thread_loaded_list( + &self, + request_id: ConnectionRequestId, + params: ThreadLoadedListParams, + ) { let ThreadLoadedListParams { cursor, limit } = params; let mut data = self .thread_manager @@ -2351,7 +2459,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn thread_read(&mut self, request_id: RequestId, params: ThreadReadParams) { + async fn thread_read(&mut self, request_id: ConnectionRequestId, params: ThreadReadParams) { let ThreadReadParams { thread_id, include_turns, @@ -2502,7 +2610,7 @@ impl CodexMessageProcessor { } } - async fn thread_resume(&mut self, request_id: RequestId, params: ThreadResumeParams) { + async fn thread_resume(&mut self, request_id: ConnectionRequestId, params: ThreadResumeParams) { let ThreadResumeParams { thread_id, history, @@ -2710,7 +2818,7 @@ impl CodexMessageProcessor { } } - async fn thread_fork(&mut self, request_id: RequestId, params: ThreadForkParams) { + async fn thread_fork(&mut self, request_id: ConnectionRequestId, params: ThreadForkParams) { let ThreadForkParams { thread_id, path, @@ -2915,7 +3023,7 @@ impl CodexMessageProcessor { async fn get_thread_summary( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: GetConversationSummaryParams, ) { if let GetConversationSummaryParams::ThreadId { conversation_id } = ¶ms @@ -2981,7 +3089,7 @@ impl CodexMessageProcessor { async fn handle_list_conversations( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ListConversationsParams, ) { let ListConversationsParams { @@ -3145,7 +3253,7 @@ 
impl CodexMessageProcessor { outgoing: Arc, thread_manager: Arc, config: Arc, - request_id: RequestId, + request_id: ConnectionRequestId, params: ModelListParams, ) { let ModelListParams { limit, cursor } = params; @@ -3208,7 +3316,7 @@ impl CodexMessageProcessor { async fn list_collaboration_modes( outgoing: Arc, thread_manager: Arc, - request_id: RequestId, + request_id: ConnectionRequestId, params: CollaborationModeListParams, ) { let CollaborationModeListParams {} = params; @@ -3219,7 +3327,7 @@ impl CodexMessageProcessor { async fn experimental_feature_list( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ExperimentalFeatureListParams, ) { let ExperimentalFeatureListParams { cursor, limit } = params; @@ -3329,7 +3437,7 @@ impl CodexMessageProcessor { async fn mock_experimental_method( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: MockExperimentalMethodParams, ) { let MockExperimentalMethodParams { value } = params; @@ -3337,7 +3445,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn mcp_server_refresh(&self, request_id: RequestId, _params: Option<()>) { + async fn mcp_server_refresh(&self, request_id: ConnectionRequestId, _params: Option<()>) { let config = match self.load_latest_config().await { Ok(config) => config, Err(error) => { @@ -3390,7 +3498,7 @@ impl CodexMessageProcessor { async fn mcp_server_oauth_login( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: McpServerOauthLoginParams, ) { let config = match self.load_latest_config().await { @@ -3487,26 +3595,28 @@ impl CodexMessageProcessor { async fn list_mcp_server_status( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ListMcpServerStatusParams, ) { + let request = request_id.clone(); + let outgoing = Arc::clone(&self.outgoing); let config = match self.load_latest_config().await { Ok(config) => config, Err(error) => { - 
self.outgoing.send_error(request_id, error).await; + self.outgoing.send_error(request, error).await; return; } }; tokio::spawn(async move { - Self::list_mcp_server_status_task(outgoing, request_id, params, config).await; + Self::list_mcp_server_status_task(outgoing, request, params, config).await; }); } async fn list_mcp_server_status_task( outgoing: Arc, - request_id: RequestId, + request_id: ConnectionRequestId, params: ListMcpServerStatusParams, config: Config, ) { @@ -3589,7 +3699,7 @@ impl CodexMessageProcessor { async fn handle_resume_conversation( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ResumeConversationParams, ) { let ResumeConversationParams { @@ -3797,7 +3907,7 @@ impl CodexMessageProcessor { async fn handle_fork_conversation( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ForkConversationParams, ) { let ForkConversationParams { @@ -3993,7 +4103,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn send_invalid_request_error(&self, request_id: RequestId, message: String) { + async fn send_invalid_request_error(&self, request_id: ConnectionRequestId, message: String) { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, message, @@ -4002,7 +4112,7 @@ impl CodexMessageProcessor { self.outgoing.send_error(request_id, error).await; } - async fn send_internal_error(&self, request_id: RequestId, message: String) { + async fn send_internal_error(&self, request_id: ConnectionRequestId, message: String) { let error = JSONRPCErrorError { code: INTERNAL_ERROR_CODE, message, @@ -4013,7 +4123,7 @@ impl CodexMessageProcessor { async fn archive_conversation( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ArchiveConversationParams, ) { let ArchiveConversationParams { @@ -4158,7 +4268,11 @@ impl CodexMessageProcessor { }) } - async fn send_user_message(&self, request_id: RequestId, params: SendUserMessageParams) { 
+ async fn send_user_message( + &self, + request_id: ConnectionRequestId, + params: SendUserMessageParams, + ) { let SendUserMessageParams { conversation_id, items, @@ -4202,7 +4316,7 @@ impl CodexMessageProcessor { .await; } - async fn send_user_turn(&self, request_id: RequestId, params: SendUserTurnParams) { + async fn send_user_turn(&self, request_id: ConnectionRequestId, params: SendUserTurnParams) { let SendUserTurnParams { conversation_id, items, @@ -4260,7 +4374,7 @@ impl CodexMessageProcessor { .await; } - async fn apps_list(&self, request_id: RequestId, params: AppsListParams) { + async fn apps_list(&self, request_id: ConnectionRequestId, params: AppsListParams) { let mut config = match self.load_latest_config().await { Ok(config) => config, Err(error) => { @@ -4307,7 +4421,7 @@ impl CodexMessageProcessor { async fn apps_list_task( outgoing: Arc, - request_id: RequestId, + request_id: ConnectionRequestId, params: AppsListParams, config: Config, ) { @@ -4478,7 +4592,7 @@ impl CodexMessageProcessor { .await; } - async fn skills_list(&self, request_id: RequestId, params: SkillsListParams) { + async fn skills_list(&self, request_id: ConnectionRequestId, params: SkillsListParams) { let SkillsListParams { cwds, force_reload, @@ -4544,7 +4658,11 @@ impl CodexMessageProcessor { .await; } - async fn skills_remote_read(&self, request_id: RequestId, _params: SkillsRemoteReadParams) { + async fn skills_remote_read( + &self, + request_id: ConnectionRequestId, + _params: SkillsRemoteReadParams, + ) { match list_remote_skills(&self.config).await { Ok(skills) => { let data = skills @@ -4569,7 +4687,11 @@ impl CodexMessageProcessor { } } - async fn skills_remote_write(&self, request_id: RequestId, params: SkillsRemoteWriteParams) { + async fn skills_remote_write( + &self, + request_id: ConnectionRequestId, + params: SkillsRemoteWriteParams, + ) { let SkillsRemoteWriteParams { hazelnut_id, is_preload, @@ -4599,7 +4721,11 @@ impl CodexMessageProcessor { } } - async fn 
skills_config_write(&self, request_id: RequestId, params: SkillsConfigWriteParams) { + async fn skills_config_write( + &self, + request_id: ConnectionRequestId, + params: SkillsConfigWriteParams, + ) { let SkillsConfigWriteParams { path, enabled } = params; let edits = vec![ConfigEdit::SetSkillConfig { path, enabled }]; let result = ConfigEditsBuilder::new(&self.config.codex_home) @@ -4632,7 +4758,7 @@ impl CodexMessageProcessor { async fn interrupt_conversation( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, params: InterruptConversationParams, ) { let InterruptConversationParams { conversation_id } = params; @@ -4646,19 +4772,21 @@ impl CodexMessageProcessor { return; }; + let request = request_id.clone(); + // Record the pending interrupt so we can reply when TurnAborted arrives. { let mut map = self.pending_interrupts.lock().await; map.entry(conversation_id) .or_default() - .push((request_id, ApiVersion::V1)); + .push((request, ApiVersion::V1)); } // Submit the interrupt; we'll respond upon TurnAborted. 
let _ = conversation.submit(Op::Interrupt).await; } - async fn turn_start(&self, request_id: RequestId, params: TurnStartParams) { + async fn turn_start(&self, request_id: ConnectionRequestId, params: TurnStartParams) { let (_, thread) = match self.load_thread(¶ms.thread_id).await { Ok(v) => v, Err(error) => { @@ -4744,7 +4872,7 @@ impl CodexMessageProcessor { } } - async fn turn_steer(&self, request_id: RequestId, params: TurnSteerParams) { + async fn turn_steer(&self, request_id: ConnectionRequestId, params: TurnSteerParams) { let (_, thread) = match self.load_thread(¶ms.thread_id).await { Ok(v) => v, Err(error) => { @@ -4825,7 +4953,7 @@ impl CodexMessageProcessor { async fn emit_review_started( &self, - request_id: &RequestId, + request_id: &ConnectionRequestId, turn: Turn, parent_thread_id: String, review_thread_id: String, @@ -4849,7 +4977,7 @@ impl CodexMessageProcessor { async fn start_inline_review( &self, - request_id: &RequestId, + request_id: &ConnectionRequestId, parent_thread: Arc, review_request: ReviewRequest, display_text: &str, @@ -4879,7 +5007,7 @@ impl CodexMessageProcessor { async fn start_detached_review( &mut self, - request_id: &RequestId, + request_id: &ConnectionRequestId, parent_thread_id: ThreadId, review_request: ReviewRequest, display_text: &str, @@ -4971,7 +5099,7 @@ impl CodexMessageProcessor { Ok(()) } - async fn review_start(&mut self, request_id: RequestId, params: ReviewStartParams) { + async fn review_start(&mut self, request_id: ConnectionRequestId, params: ReviewStartParams) { let ReviewStartParams { thread_id, target, @@ -5025,7 +5153,11 @@ impl CodexMessageProcessor { } } - async fn turn_interrupt(&mut self, request_id: RequestId, params: TurnInterruptParams) { + async fn turn_interrupt( + &mut self, + request_id: ConnectionRequestId, + params: TurnInterruptParams, + ) { let TurnInterruptParams { thread_id, .. 
} = params; let (thread_uuid, thread) = match self.load_thread(&thread_id).await { @@ -5036,12 +5168,14 @@ impl CodexMessageProcessor { } }; + let request = request_id.clone(); + // Record the pending interrupt so we can reply when TurnAborted arrives. { let mut map = self.pending_interrupts.lock().await; map.entry(thread_uuid) .or_default() - .push((request_id, ApiVersion::V2)); + .push((request, ApiVersion::V2)); } // Submit the interrupt; we'll respond upon TurnAborted. @@ -5050,7 +5184,7 @@ impl CodexMessageProcessor { async fn add_conversation_listener( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, params: AddConversationListenerParams, ) { let AddConversationListenerParams { @@ -5073,7 +5207,7 @@ impl CodexMessageProcessor { async fn remove_thread_listener( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, params: RemoveConversationListenerParams, ) { let RemoveConversationListenerParams { subscription_id } = params; @@ -5203,7 +5337,7 @@ impl CodexMessageProcessor { Ok(subscription_id) } - async fn git_diff_to_origin(&self, request_id: RequestId, cwd: PathBuf) { + async fn git_diff_to_origin(&self, request_id: ConnectionRequestId, cwd: PathBuf) { let diff = git_diff_to_remote(&cwd).await; match diff { Some(value) => { @@ -5224,7 +5358,11 @@ impl CodexMessageProcessor { } } - async fn fuzzy_file_search(&mut self, request_id: RequestId, params: FuzzyFileSearchParams) { + async fn fuzzy_file_search( + &mut self, + request_id: ConnectionRequestId, + params: FuzzyFileSearchParams, + ) { let FuzzyFileSearchParams { query, roots, @@ -5264,7 +5402,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn upload_feedback(&self, request_id: RequestId, params: FeedbackUploadParams) { + async fn upload_feedback(&self, request_id: ConnectionRequestId, params: FeedbackUploadParams) { if !self.config.feedback_enabled { let error = JSONRPCErrorError { code: 
INVALID_REQUEST_ERROR_CODE, diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 1b940c70d81..ad049ad3055 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -8,14 +8,24 @@ use codex_core::config::ConfigBuilder; use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::ConfigLayerStackOrdering; use codex_core::config_loader::LoaderOverrides; +use std::collections::HashMap; use std::io::ErrorKind; use std::io::Result as IoResult; use std::path::PathBuf; +use std::sync::Arc; use crate::message_processor::MessageProcessor; use crate::message_processor::MessageProcessorArgs; -use crate::outgoing_message::OutgoingMessage; +use crate::outgoing_message::ConnectionId; +use crate::outgoing_message::OutgoingEnvelope; use crate::outgoing_message::OutgoingMessageSender; +use crate::transport::CHANNEL_CAPACITY; +use crate::transport::ConnectionState; +use crate::transport::TransportEvent; +use crate::transport::has_initialized_connections; +use crate::transport::route_outgoing_envelope; +use crate::transport::start_stdio_connection; +use crate::transport::start_websocket_acceptor; use codex_app_server_protocol::ConfigLayerSource; use codex_app_server_protocol::ConfigWarningNotification; use codex_app_server_protocol::JSONRPCMessage; @@ -26,13 +36,9 @@ use codex_core::check_execpolicy_for_warnings; use codex_core::config_loader::ConfigLoadError; use codex_core::config_loader::TextRange as CoreTextRange; use codex_feedback::CodexFeedback; -use tokio::io::AsyncBufReadExt; -use tokio::io::AsyncWriteExt; -use tokio::io::BufReader; -use tokio::io::{self}; use tokio::sync::mpsc; +use tokio::task::JoinHandle; use toml::Value as TomlValue; -use tracing::debug; use tracing::error; use tracing::info; use tracing::warn; @@ -51,11 +57,9 @@ mod fuzzy_file_search; mod message_processor; mod models; mod outgoing_message; +mod transport; -/// Size of the bounded channels used to communicate between tasks. 
The value -/// is a balance between throughput and memory usage – 128 messages should be -/// plenty for an interactive CLI. -const CHANNEL_CAPACITY: usize = 128; +pub use crate::transport::AppServerTransport; fn config_warning_from_error( summary: impl Into, @@ -173,32 +177,39 @@ pub async fn run_main( loader_overrides: LoaderOverrides, default_analytics_enabled: bool, ) -> IoResult<()> { - // Set up channels. - let (incoming_tx, mut incoming_rx) = mpsc::channel::(CHANNEL_CAPACITY); - let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); - - // Task: read from stdin, push to `incoming_tx`. - let stdin_reader_handle = tokio::spawn({ - async move { - let stdin = io::stdin(); - let reader = BufReader::new(stdin); - let mut lines = reader.lines(); - - while let Some(line) = lines.next_line().await.unwrap_or_default() { - match serde_json::from_str::(&line) { - Ok(msg) => { - if incoming_tx.send(msg).await.is_err() { - // Receiver gone – nothing left to do. - break; - } - } - Err(e) => error!("Failed to deserialize JSONRPCMessage: {e}"), - } - } + run_main_with_transport( + codex_linux_sandbox_exe, + cli_config_overrides, + loader_overrides, + default_analytics_enabled, + AppServerTransport::Stdio, + ) + .await +} - debug!("stdin reader finished (EOF)"); +pub async fn run_main_with_transport( + codex_linux_sandbox_exe: Option, + cli_config_overrides: CliConfigOverrides, + loader_overrides: LoaderOverrides, + default_analytics_enabled: bool, + transport: AppServerTransport, +) -> IoResult<()> { + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); + + let mut stdio_handles = Vec::>::new(); + let mut websocket_accept_handle = None; + match transport { + AppServerTransport::Stdio => { + start_stdio_connection(transport_event_tx.clone(), &mut stdio_handles).await?; } - }); + AppServerTransport::WebSocket { bind_address } => { + 
websocket_accept_handle = + Some(start_websocket_acceptor(bind_address, transport_event_tx.clone()).await?); + } + } + let shutdown_when_no_connections = matches!(transport, AppServerTransport::Stdio); // Parse CLI overrides once and derive the base Config eagerly so later // components do not need to work with raw TOML values. @@ -325,15 +336,14 @@ pub async fn run_main( } } - // Task: process incoming messages. let processor_handle = tokio::spawn({ - let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx); + let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx)); let cli_overrides: Vec<(String, TomlValue)> = cli_kv_overrides.clone(); let loader_overrides = loader_overrides_for_config_api; let mut processor = MessageProcessor::new(MessageProcessorArgs { outgoing: outgoing_message_sender, codex_linux_sandbox_exe, - config: std::sync::Arc::new(config), + config: Arc::new(config), cli_overrides, loader_overrides, cloud_requirements: cloud_requirements.clone(), @@ -341,25 +351,65 @@ pub async fn run_main( config_warnings, }); let mut thread_created_rx = processor.thread_created_receiver(); + let mut connections = HashMap::::new(); async move { let mut listen_for_threads = true; loop { tokio::select! 
{ - msg = incoming_rx.recv() => { - let Some(msg) = msg else { + event = transport_event_rx.recv() => { + let Some(event) = event else { break; }; - match msg { - JSONRPCMessage::Request(r) => processor.process_request(r).await, - JSONRPCMessage::Response(r) => processor.process_response(r).await, - JSONRPCMessage::Notification(n) => processor.process_notification(n).await, - JSONRPCMessage::Error(e) => processor.process_error(e).await, + match event { + TransportEvent::ConnectionOpened { connection_id, writer } => { + connections.insert(connection_id, ConnectionState::new(writer)); + } + TransportEvent::ConnectionClosed { connection_id } => { + connections.remove(&connection_id); + if shutdown_when_no_connections && connections.is_empty() { + break; + } + } + TransportEvent::IncomingMessage { connection_id, message } => { + match message { + JSONRPCMessage::Request(request) => { + let Some(connection_state) = connections.get_mut(&connection_id) else { + warn!("dropping request from unknown connection: {:?}", connection_id); + continue; + }; + processor + .process_request( + connection_id, + request, + &mut connection_state.session, + ) + .await; + } + JSONRPCMessage::Response(response) => { + processor.process_response(response).await; + } + JSONRPCMessage::Notification(notification) => { + processor.process_notification(notification).await; + } + JSONRPCMessage::Error(err) => { + processor.process_error(err).await; + } + } + } } } + envelope = outgoing_rx.recv() => { + let Some(envelope) = envelope else { + break; + }; + route_outgoing_envelope(&mut connections, envelope).await; + } created = thread_created_rx.recv(), if listen_for_threads => { match created { Ok(thread_id) => { - processor.try_attach_thread_listener(thread_id).await; + if has_initialized_connections(&connections) { + processor.try_attach_thread_listener(thread_id).await; + } } Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { // TODO(jif) handle lag. 
@@ -380,33 +430,17 @@ pub async fn run_main( } }); - // Task: write outgoing messages to stdout. - let stdout_writer_handle = tokio::spawn(async move { - let mut stdout = io::stdout(); - while let Some(outgoing_message) = outgoing_rx.recv().await { - let Ok(value) = serde_json::to_value(outgoing_message) else { - error!("Failed to convert OutgoingMessage to JSON value"); - continue; - }; - match serde_json::to_string(&value) { - Ok(mut json) => { - json.push('\n'); - if let Err(e) = stdout.write_all(json.as_bytes()).await { - error!("Failed to write to stdout: {e}"); - break; - } - } - Err(e) => error!("Failed to serialize JSONRPCMessage: {e}"), - } - } + drop(transport_event_tx); - info!("stdout writer exited (channel closed)"); - }); + let _ = processor_handle.await; + + if let Some(handle) = websocket_accept_handle { + handle.abort(); + } - // Wait for all tasks to finish. The typical exit path is the stdin reader - // hitting EOF which, once it drops `incoming_tx`, propagates shutdown to - // the processor and then to the stdout task. - let _ = tokio::join!(stdin_reader_handle, processor_handle, stdout_writer_handle); + for handle in stdio_handles { + let _ = handle.await; + } Ok(()) } diff --git a/codex-rs/app-server/src/main.rs b/codex-rs/app-server/src/main.rs index 71d6dc338c2..40dec1dc80c 100644 --- a/codex-rs/app-server/src/main.rs +++ b/codex-rs/app-server/src/main.rs @@ -1,4 +1,6 @@ -use codex_app_server::run_main; +use clap::Parser; +use codex_app_server::AppServerTransport; +use codex_app_server::run_main_with_transport; use codex_arg0::arg0_dispatch_or_else; use codex_common::CliConfigOverrides; use codex_core::config_loader::LoaderOverrides; @@ -8,19 +10,34 @@ use std::path::PathBuf; // managed config file without writing to /etc. const MANAGED_CONFIG_PATH_ENV_VAR: &str = "CODEX_APP_SERVER_MANAGED_CONFIG_PATH"; +#[derive(Debug, Parser)] +struct AppServerArgs { + /// Transport endpoint URL. 
Supported values: `stdio://` (default), + /// `ws://IP:PORT`. + #[arg( + long = "listen", + value_name = "URL", + default_value = AppServerTransport::DEFAULT_LISTEN_URL + )] + listen: AppServerTransport, +} + fn main() -> anyhow::Result<()> { arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move { + let args = AppServerArgs::parse(); let managed_config_path = managed_config_path_from_debug_env(); let loader_overrides = LoaderOverrides { managed_config_path, ..Default::default() }; + let transport = args.listen; - run_main( + run_main_with_transport( codex_linux_sandbox_exe, CliConfigOverrides::default(), loader_overrides, false, + transport, ) .await?; Ok(()) diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 368d24e7270..2d8e18a46ef 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -1,13 +1,13 @@ use std::path::PathBuf; use std::sync::Arc; use std::sync::RwLock; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering; use crate::codex_message_processor::CodexMessageProcessor; use crate::codex_message_processor::CodexMessageProcessorArgs; use crate::config_api::ConfigApi; use crate::error_code::INVALID_REQUEST_ERROR_CODE; +use crate::outgoing_message::ConnectionId; +use crate::outgoing_message::ConnectionRequestId; use crate::outgoing_message::OutgoingMessageSender; use async_trait::async_trait; use codex_app_server_protocol::ChatgptAuthTokensRefreshParams; @@ -26,7 +26,6 @@ use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::JSONRPCNotification; use codex_app_server_protocol::JSONRPCRequest; use codex_app_server_protocol::JSONRPCResponse; -use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ServerNotification; use codex_app_server_protocol::ServerRequestPayload; use codex_app_server_protocol::experimental_required_message; @@ -112,13 +111,17 @@ pub(crate) struct 
MessageProcessor { codex_message_processor: CodexMessageProcessor, config_api: ConfigApi, config: Arc, - initialized: bool, - experimental_api_enabled: Arc, - config_warnings: Vec, + config_warnings: Arc>, +} + +#[derive(Debug, Default)] +pub(crate) struct ConnectionSessionState { + pub(crate) initialized: bool, + experimental_api_enabled: bool, } pub(crate) struct MessageProcessorArgs { - pub(crate) outgoing: OutgoingMessageSender, + pub(crate) outgoing: Arc, pub(crate) codex_linux_sandbox_exe: Option, pub(crate) config: Arc, pub(crate) cli_overrides: Vec<(String, TomlValue)>, @@ -142,8 +145,6 @@ impl MessageProcessor { feedback, config_warnings, } = args; - let outgoing = Arc::new(outgoing); - let experimental_api_enabled = Arc::new(AtomicBool::new(false)); let auth_manager = AuthManager::shared( config.codex_home.clone(), false, @@ -181,14 +182,20 @@ impl MessageProcessor { codex_message_processor, config_api, config, - initialized: false, - experimental_api_enabled, - config_warnings, + config_warnings: Arc::new(config_warnings), } } - pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) { - let request_id = request.id.clone(); + pub(crate) async fn process_request( + &mut self, + connection_id: ConnectionId, + request: JSONRPCRequest, + session: &mut ConnectionSessionState, + ) { + let request_id = ConnectionRequestId { + connection_id, + request_id: request.id.clone(), + }; let request_json = match serde_json::to_value(&request) { Ok(request_json) => request_json, Err(err) => { @@ -219,7 +226,11 @@ impl MessageProcessor { // Handle Initialize internally so CodexMessageProcessor does not have to concern // itself with the `initialized` bool. 
ClientRequest::Initialize { request_id, params } => { - if self.initialized { + let request_id = ConnectionRequestId { + connection_id, + request_id, + }; + if session.initialized { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, message: "Already initialized".to_string(), @@ -228,6 +239,12 @@ impl MessageProcessor { self.outgoing.send_error(request_id, error).await; return; } else { + // TODO(maxj): Revisit capability scoping for `experimental_api_enabled`. + // Current behavior is per-connection. Reviewer feedback notes this can + // create odd cross-client behavior (for example dynamic tool calls on a + // shared thread when another connected client did not opt into + // experimental API). Proposed direction is instance-global first-write-wins + // with initialize-time mismatch rejection. let (experimental_api_enabled, opt_out_notification_methods) = match params.capabilities { Some(capabilities) => ( @@ -238,8 +255,7 @@ impl MessageProcessor { ), None => (false, Vec::new()), }; - self.experimental_api_enabled - .store(experimental_api_enabled, Ordering::Relaxed); + session.experimental_api_enabled = experimental_api_enabled; self.outgoing .set_opted_out_notification_methods(opt_out_notification_methods) .await; @@ -258,7 +274,7 @@ impl MessageProcessor { ), data: None, }; - self.outgoing.send_error(request_id, error).await; + self.outgoing.send_error(request_id.clone(), error).await; return; } SetOriginatorError::AlreadyInitialized => { @@ -279,22 +295,20 @@ impl MessageProcessor { let response = InitializeResponse { user_agent }; self.outgoing.send_response(request_id, response).await; - self.initialized = true; - if !self.config_warnings.is_empty() { - for notification in self.config_warnings.drain(..) 
{ - self.outgoing - .send_server_notification(ServerNotification::ConfigWarning( - notification, - )) - .await; - } + session.initialized = true; + for notification in self.config_warnings.iter().cloned() { + self.outgoing + .send_server_notification(ServerNotification::ConfigWarning( + notification, + )) + .await; } return; } } _ => { - if !self.initialized { + if !session.initialized { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, message: "Not initialized".to_string(), @@ -307,7 +321,7 @@ impl MessageProcessor { } if let Some(reason) = codex_request.experimental_reason() - && !self.experimental_api_enabled.load(Ordering::Relaxed) + && !session.experimental_api_enabled { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, @@ -320,22 +334,49 @@ impl MessageProcessor { match codex_request { ClientRequest::ConfigRead { request_id, params } => { - self.handle_config_read(request_id, params).await; + self.handle_config_read( + ConnectionRequestId { + connection_id, + request_id, + }, + params, + ) + .await; } ClientRequest::ConfigValueWrite { request_id, params } => { - self.handle_config_value_write(request_id, params).await; + self.handle_config_value_write( + ConnectionRequestId { + connection_id, + request_id, + }, + params, + ) + .await; } ClientRequest::ConfigBatchWrite { request_id, params } => { - self.handle_config_batch_write(request_id, params).await; + self.handle_config_batch_write( + ConnectionRequestId { + connection_id, + request_id, + }, + params, + ) + .await; } ClientRequest::ConfigRequirementsRead { request_id, params: _, } => { - self.handle_config_requirements_read(request_id).await; + self.handle_config_requirements_read(ConnectionRequestId { + connection_id, + request_id, + }) + .await; } other => { - self.codex_message_processor.process_request(other).await; + self.codex_message_processor + .process_request(connection_id, other) + .await; } } } @@ -351,9 +392,6 @@ impl MessageProcessor { } pub(crate) async fn 
try_attach_thread_listener(&mut self, thread_id: ThreadId) { - if !self.initialized { - return; - } self.codex_message_processor .try_attach_thread_listener(thread_id) .await; @@ -372,7 +410,7 @@ impl MessageProcessor { self.outgoing.notify_client_error(err.id, err.error).await; } - async fn handle_config_read(&self, request_id: RequestId, params: ConfigReadParams) { + async fn handle_config_read(&self, request_id: ConnectionRequestId, params: ConfigReadParams) { match self.config_api.read(params).await { Ok(response) => self.outgoing.send_response(request_id, response).await, Err(error) => self.outgoing.send_error(request_id, error).await, @@ -381,7 +419,7 @@ impl MessageProcessor { async fn handle_config_value_write( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ConfigValueWriteParams, ) { match self.config_api.write_value(params).await { @@ -392,7 +430,7 @@ impl MessageProcessor { async fn handle_config_batch_write( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ConfigBatchWriteParams, ) { match self.config_api.batch_write(params).await { @@ -401,7 +439,7 @@ impl MessageProcessor { } } - async fn handle_config_requirements_read(&self, request_id: RequestId) { + async fn handle_config_requirements_read(&self, request_id: ConnectionRequestId) { match self.config_api.config_requirements_read().await { Ok(response) => self.outgoing.send_response(request_id, response).await, Err(error) => self.outgoing.send_error(request_id, error).await, diff --git a/codex-rs/app-server/src/outgoing_message.rs b/codex-rs/app-server/src/outgoing_message.rs index b64bd5bce95..d58ed8543a3 100644 --- a/codex-rs/app-server/src/outgoing_message.rs +++ b/codex-rs/app-server/src/outgoing_message.rs @@ -20,18 +20,40 @@ use crate::error_code::INTERNAL_ERROR_CODE; #[cfg(test)] use codex_protocol::account::PlanType; +/// Stable identifier for a transport connection. 
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub(crate) struct ConnectionId(pub(crate) u64); + +/// Stable identifier for a client request scoped to a transport connection. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub(crate) struct ConnectionRequestId { + pub(crate) connection_id: ConnectionId, + pub(crate) request_id: RequestId, +} + +#[derive(Debug, Clone)] +pub(crate) enum OutgoingEnvelope { + ToConnection { + connection_id: ConnectionId, + message: OutgoingMessage, + }, + Broadcast { + message: OutgoingMessage, + }, +} + /// Sends messages to the client and manages request callbacks. pub(crate) struct OutgoingMessageSender { - next_request_id: AtomicI64, - sender: mpsc::Sender, + next_server_request_id: AtomicI64, + sender: mpsc::Sender, request_id_to_callback: Mutex>>, opted_out_notification_methods: Mutex>, } impl OutgoingMessageSender { - pub(crate) fn new(sender: mpsc::Sender) -> Self { + pub(crate) fn new(sender: mpsc::Sender) -> Self { Self { - next_request_id: AtomicI64::new(0), + next_server_request_id: AtomicI64::new(0), sender, request_id_to_callback: Mutex::new(HashMap::new()), opted_out_notification_methods: Mutex::new(HashSet::new()), @@ -61,7 +83,7 @@ impl OutgoingMessageSender { &self, request: ServerRequestPayload, ) -> (RequestId, oneshot::Receiver) { - let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed)); + let id = RequestId::Integer(self.next_server_request_id.fetch_add(1, Ordering::Relaxed)); let outgoing_message_id = id.clone(); let (tx_approve, rx_approve) = oneshot::channel(); { @@ -71,7 +93,13 @@ impl OutgoingMessageSender { let outgoing_message = OutgoingMessage::Request(request.request_with_id(outgoing_message_id.clone())); - if let Err(err) = self.sender.send(outgoing_message).await { + if let Err(err) = self + .sender + .send(OutgoingEnvelope::Broadcast { + message: outgoing_message, + }) + .await + { warn!("failed to send request {outgoing_message_id:?} to client: {err:?}"); let mut 
request_id_to_callback = self.request_id_to_callback.lock().await; request_id_to_callback.remove(&outgoing_message_id); @@ -121,17 +149,31 @@ impl OutgoingMessageSender { entry.is_some() } - pub(crate) async fn send_response(&self, id: RequestId, response: T) { + pub(crate) async fn send_response( + &self, + request_id: ConnectionRequestId, + response: T, + ) { match serde_json::to_value(response) { Ok(result) => { - let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result }); - if let Err(err) = self.sender.send(outgoing_message).await { + let outgoing_message = OutgoingMessage::Response(OutgoingResponse { + id: request_id.request_id, + result, + }); + if let Err(err) = self + .sender + .send(OutgoingEnvelope::ToConnection { + connection_id: request_id.connection_id, + message: outgoing_message, + }) + .await + { warn!("failed to send response to client: {err:?}"); } } Err(err) => { self.send_error( - id, + request_id, JSONRPCErrorError { code: INTERNAL_ERROR_CODE, message: format!("failed to serialize response: {err}"), @@ -150,7 +192,9 @@ impl OutgoingMessageSender { } if let Err(err) = self .sender - .send(OutgoingMessage::AppServerNotification(notification)) + .send(OutgoingEnvelope::Broadcast { + message: OutgoingMessage::AppServerNotification(notification), + }) .await { warn!("failed to send server notification to client: {err:?}"); @@ -167,14 +211,34 @@ impl OutgoingMessageSender { return; } let outgoing_message = OutgoingMessage::Notification(notification); - if let Err(err) = self.sender.send(outgoing_message).await { + if let Err(err) = self + .sender + .send(OutgoingEnvelope::Broadcast { + message: outgoing_message, + }) + .await + { warn!("failed to send notification to client: {err:?}"); } } - pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) { - let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error }); - if let Err(err) = self.sender.send(outgoing_message).await { + pub(crate) 
async fn send_error( + &self, + request_id: ConnectionRequestId, + error: JSONRPCErrorError, + ) { + let outgoing_message = OutgoingMessage::Error(OutgoingError { + id: request_id.request_id, + error, + }); + if let Err(err) = self + .sender + .send(OutgoingEnvelope::ToConnection { + connection_id: request_id.connection_id, + message: outgoing_message, + }) + .await + { warn!("failed to send error to client: {err:?}"); } } @@ -214,6 +278,8 @@ pub(crate) struct OutgoingError { #[cfg(test)] mod tests { + use std::time::Duration; + use codex_app_server_protocol::AccountLoginCompletedNotification; use codex_app_server_protocol::AccountRateLimitsUpdatedNotification; use codex_app_server_protocol::AccountUpdatedNotification; @@ -224,6 +290,7 @@ mod tests { use codex_app_server_protocol::RateLimitWindow; use pretty_assertions::assert_eq; use serde_json::json; + use tokio::time::timeout; use uuid::Uuid; use super::*; @@ -360,4 +427,75 @@ mod tests { "ensure the notification serializes correctly" ); } + + #[tokio::test] + async fn send_response_routes_to_target_connection() { + let (tx, mut rx) = mpsc::channel::(4); + let outgoing = OutgoingMessageSender::new(tx); + let request_id = ConnectionRequestId { + connection_id: ConnectionId(42), + request_id: RequestId::Integer(7), + }; + + outgoing + .send_response(request_id.clone(), json!({ "ok": true })) + .await; + + let envelope = timeout(Duration::from_secs(1), rx.recv()) + .await + .expect("should receive envelope before timeout") + .expect("channel should contain one message"); + + match envelope { + OutgoingEnvelope::ToConnection { + connection_id, + message, + } => { + assert_eq!(connection_id, ConnectionId(42)); + let OutgoingMessage::Response(response) = message else { + panic!("expected response message"); + }; + assert_eq!(response.id, request_id.request_id); + assert_eq!(response.result, json!({ "ok": true })); + } + other => panic!("expected targeted response envelope, got: {other:?}"), + } + } + + #[tokio::test] 
+ async fn send_error_routes_to_target_connection() { + let (tx, mut rx) = mpsc::channel::(4); + let outgoing = OutgoingMessageSender::new(tx); + let request_id = ConnectionRequestId { + connection_id: ConnectionId(9), + request_id: RequestId::Integer(3), + }; + let error = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: "boom".to_string(), + data: None, + }; + + outgoing.send_error(request_id.clone(), error.clone()).await; + + let envelope = timeout(Duration::from_secs(1), rx.recv()) + .await + .expect("should receive envelope before timeout") + .expect("channel should contain one message"); + + match envelope { + OutgoingEnvelope::ToConnection { + connection_id, + message, + } => { + assert_eq!(connection_id, ConnectionId(9)); + let OutgoingMessage::Error(outgoing_error) = message else { + panic!("expected error message"); + }; + assert_eq!(outgoing_error.id, RequestId::Integer(3)); + assert_eq!(outgoing_error.error, error); + } + other => panic!("expected targeted error envelope, got: {other:?}"), + } + } } diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs new file mode 100644 index 00000000000..39fd13212cf --- /dev/null +++ b/codex-rs/app-server/src/transport.rs @@ -0,0 +1,459 @@ +use crate::message_processor::ConnectionSessionState; +use crate::outgoing_message::ConnectionId; +use crate::outgoing_message::OutgoingEnvelope; +use crate::outgoing_message::OutgoingMessage; +use codex_app_server_protocol::JSONRPCMessage; +use futures::SinkExt; +use futures::StreamExt; +use owo_colors::OwoColorize; +use owo_colors::Stream; +use owo_colors::Style; +use std::collections::HashMap; +use std::io::ErrorKind; +use std::io::Result as IoResult; +use std::net::SocketAddr; +use std::str::FromStr; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::io::{self}; +use tokio::net::TcpListener; 
+use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::tungstenite::Message as WebSocketMessage; +use tracing::debug; +use tracing::error; +use tracing::info; +use tracing::warn; + +/// Size of the bounded channels used to communicate between tasks. The value +/// is a balance between throughput and memory usage - 128 messages should be +/// plenty for an interactive CLI. +pub(crate) const CHANNEL_CAPACITY: usize = 128; + +fn colorize(text: &str, style: Style) -> String { + text.if_supports_color(Stream::Stderr, |value| value.style(style)) + .to_string() +} + +#[allow(clippy::print_stderr)] +fn print_websocket_startup_banner(addr: SocketAddr) { + let title = colorize("codex app-server (WebSockets)", Style::new().bold().cyan()); + let listening_label = colorize("listening on:", Style::new().dimmed()); + let listen_url = colorize(&format!("ws://{addr}"), Style::new().green()); + let note_label = colorize("note:", Style::new().dimmed()); + eprintln!("{title}"); + eprintln!(" {listening_label} {listen_url}"); + if addr.ip().is_loopback() { + eprintln!( + " {note_label} binds localhost only (use SSH port-forwarding for remote access)" + ); + } else { + eprintln!( + " {note_label} this is a raw WS server; consider running behind TLS/auth for real remote use" + ); + } +} + +#[allow(clippy::print_stderr)] +fn print_websocket_connection(peer_addr: SocketAddr) { + let connected_label = colorize("websocket client connected from", Style::new().dimmed()); + eprintln!("{connected_label} {peer_addr}"); +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum AppServerTransport { + Stdio, + WebSocket { bind_address: SocketAddr }, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum AppServerTransportParseError { + UnsupportedListenUrl(String), + InvalidWebSocketListenUrl(String), +} + +impl std::fmt::Display for AppServerTransportParseError { + fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!(
+                f,
+                "unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`"
+            ),
+            AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!(
+                f,
+                "invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`"
+            ),
+        }
+    }
+}
+
+impl std::error::Error for AppServerTransportParseError {}
+
+impl AppServerTransport {
+    pub const DEFAULT_LISTEN_URL: &'static str = "stdio://";
+
+    pub fn from_listen_url(listen_url: &str) -> Result<Self, AppServerTransportParseError> {
+        if listen_url == Self::DEFAULT_LISTEN_URL {
+            return Ok(Self::Stdio);
+        }
+
+        if let Some(socket_addr) = listen_url.strip_prefix("ws://") {
+            let bind_address = socket_addr.parse::<SocketAddr>().map_err(|_| {
+                AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string())
+            })?;
+            return Ok(Self::WebSocket { bind_address });
+        }
+
+        Err(AppServerTransportParseError::UnsupportedListenUrl(
+            listen_url.to_string(),
+        ))
+    }
+}
+
+impl FromStr for AppServerTransport {
+    type Err = AppServerTransportParseError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        Self::from_listen_url(s)
+    }
+}
+
+#[derive(Debug)]
+pub(crate) enum TransportEvent {
+    ConnectionOpened {
+        connection_id: ConnectionId,
+        writer: mpsc::Sender<OutgoingMessage>,
+    },
+    ConnectionClosed {
+        connection_id: ConnectionId,
+    },
+    IncomingMessage {
+        connection_id: ConnectionId,
+        message: JSONRPCMessage,
+    },
+}
+
+pub(crate) struct ConnectionState {
+    pub(crate) writer: mpsc::Sender<OutgoingMessage>,
+    pub(crate) session: ConnectionSessionState,
+}
+
+impl ConnectionState {
+    pub(crate) fn new(writer: mpsc::Sender<OutgoingMessage>) -> Self {
+        Self {
+            writer,
+            session: ConnectionSessionState::default(),
+        }
+    }
+}
+
+pub(crate) async fn start_stdio_connection(
+    transport_event_tx: mpsc::Sender<TransportEvent>,
+    stdio_handles: &mut Vec<JoinHandle<()>>,
+) -> IoResult<()> {
+    let connection_id = ConnectionId(0);
+    let (writer_tx, mut writer_rx) = 
mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
+    transport_event_tx
+        .send(TransportEvent::ConnectionOpened {
+            connection_id,
+            writer: writer_tx,
+        })
+        .await
+        .map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "processor unavailable"))?;
+
+    let transport_event_tx_for_reader = transport_event_tx.clone();
+    stdio_handles.push(tokio::spawn(async move {
+        let stdin = io::stdin();
+        let reader = BufReader::new(stdin);
+        let mut lines = reader.lines();
+
+        loop {
+            match lines.next_line().await {
+                Ok(Some(line)) => {
+                    if !forward_incoming_message(
+                        &transport_event_tx_for_reader,
+                        connection_id,
+                        &line,
+                    )
+                    .await
+                    {
+                        break;
+                    }
+                }
+                Ok(None) => break,
+                Err(err) => {
+                    error!("Failed reading stdin: {err}");
+                    break;
+                }
+            }
+        }
+
+        let _ = transport_event_tx_for_reader
+            .send(TransportEvent::ConnectionClosed { connection_id })
+            .await;
+        debug!("stdin reader finished (EOF)");
+    }));
+
+    stdio_handles.push(tokio::spawn(async move {
+        let mut stdout = io::stdout();
+        while let Some(outgoing_message) = writer_rx.recv().await {
+            let Some(mut json) = serialize_outgoing_message(outgoing_message) else {
+                continue;
+            };
+            json.push('\n');
+            if let Err(err) = stdout.write_all(json.as_bytes()).await {
+                error!("Failed to write to stdout: {err}");
+                break;
+            }
+        }
+        info!("stdout writer exited (channel closed)");
+    }));
+
+    Ok(())
+}
+
+pub(crate) async fn start_websocket_acceptor(
+    bind_address: SocketAddr,
+    transport_event_tx: mpsc::Sender<TransportEvent>,
+) -> IoResult<JoinHandle<()>> {
+    let listener = TcpListener::bind(bind_address).await?;
+    let local_addr = listener.local_addr()?;
+    print_websocket_startup_banner(local_addr);
+    info!("app-server websocket listening on ws://{local_addr}");
+
+    let connection_counter = Arc::new(AtomicU64::new(1));
+    Ok(tokio::spawn(async move {
+        loop {
+            match listener.accept().await {
+                Ok((stream, peer_addr)) => {
+                    print_websocket_connection(peer_addr);
+                    let connection_id =
+                        ConnectionId(connection_counter.fetch_add(1, Ordering::Relaxed));
+                    let 
transport_event_tx_for_connection = transport_event_tx.clone();
+                    tokio::spawn(async move {
+                        run_websocket_connection(
+                            connection_id,
+                            stream,
+                            transport_event_tx_for_connection,
+                        )
+                        .await;
+                    });
+                }
+                Err(err) => {
+                    error!("failed to accept websocket connection: {err}");
+                }
+            }
+        }
+    }))
+}
+
+async fn run_websocket_connection(
+    connection_id: ConnectionId,
+    stream: TcpStream,
+    transport_event_tx: mpsc::Sender<TransportEvent>,
+) {
+    let websocket_stream = match accept_async(stream).await {
+        Ok(stream) => stream,
+        Err(err) => {
+            warn!("failed to complete websocket handshake: {err}");
+            return;
+        }
+    };
+
+    let (writer_tx, mut writer_rx) = mpsc::channel::<OutgoingMessage>(CHANNEL_CAPACITY);
+    if transport_event_tx
+        .send(TransportEvent::ConnectionOpened {
+            connection_id,
+            writer: writer_tx,
+        })
+        .await
+        .is_err()
+    {
+        return;
+    }
+
+    let (mut websocket_writer, mut websocket_reader) = websocket_stream.split();
+    loop {
+        tokio::select! {
+            outgoing_message = writer_rx.recv() => {
+                let Some(outgoing_message) = outgoing_message else {
+                    break;
+                };
+                let Some(json) = serialize_outgoing_message(outgoing_message) else {
+                    continue;
+                };
+                if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() {
+                    break;
+                }
+            }
+            incoming_message = websocket_reader.next() => {
+                match incoming_message {
+                    Some(Ok(WebSocketMessage::Text(text))) => {
+                        if !forward_incoming_message(&transport_event_tx, connection_id, &text).await {
+                            break;
+                        }
+                    }
+                    Some(Ok(WebSocketMessage::Ping(payload))) => {
+                        if websocket_writer.send(WebSocketMessage::Pong(payload)).await.is_err() {
+                            break;
+                        }
+                    }
+                    Some(Ok(WebSocketMessage::Pong(_))) => {}
+                    Some(Ok(WebSocketMessage::Close(_))) | None => break,
+                    Some(Ok(WebSocketMessage::Binary(_))) => {
+                        warn!("dropping unsupported binary websocket message");
+                    }
+                    Some(Ok(WebSocketMessage::Frame(_))) => {}
+                    Some(Err(err)) => {
+                        warn!("websocket receive error: {err}");
+                        break;
+                    }
+                }
+            }
+        }
+    }
+
+    let _ = transport_event_tx
+        
.send(TransportEvent::ConnectionClosed { connection_id })
+        .await;
+}
+
+async fn forward_incoming_message(
+    transport_event_tx: &mpsc::Sender<TransportEvent>,
+    connection_id: ConnectionId,
+    payload: &str,
+) -> bool {
+    match serde_json::from_str::<JSONRPCMessage>(payload) {
+        Ok(message) => transport_event_tx
+            .send(TransportEvent::IncomingMessage {
+                connection_id,
+                message,
+            })
+            .await
+            .is_ok(),
+        Err(err) => {
+            error!("Failed to deserialize JSONRPCMessage: {err}");
+            true
+        }
+    }
+}
+
+fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option<String> {
+    let value = match serde_json::to_value(outgoing_message) {
+        Ok(value) => value,
+        Err(err) => {
+            error!("Failed to convert OutgoingMessage to JSON value: {err}");
+            return None;
+        }
+    };
+    match serde_json::to_string(&value) {
+        Ok(json) => Some(json),
+        Err(err) => {
+            error!("Failed to serialize JSONRPCMessage: {err}");
+            None
+        }
+    }
+}
+
+pub(crate) async fn route_outgoing_envelope(
+    connections: &mut HashMap<ConnectionId, ConnectionState>,
+    envelope: OutgoingEnvelope,
+) {
+    match envelope {
+        OutgoingEnvelope::ToConnection {
+            connection_id,
+            message,
+        } => {
+            let Some(connection_state) = connections.get(&connection_id) else {
+                warn!(
+                    "dropping message for disconnected connection: {:?}",
+                    connection_id
+                );
+                return;
+            };
+            if connection_state.writer.send(message).await.is_err() {
+                connections.remove(&connection_id);
+            }
+        }
+        OutgoingEnvelope::Broadcast { message } => {
+            let target_connections: Vec<ConnectionId> = connections
+                .iter()
+                .filter_map(|(connection_id, connection_state)| {
+                    if connection_state.session.initialized {
+                        Some(*connection_id)
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+
+            for connection_id in target_connections {
+                let Some(connection_state) = connections.get(&connection_id) else {
+                    continue;
+                };
+                if connection_state.writer.send(message.clone()).await.is_err() {
+                    connections.remove(&connection_id);
+                }
+            }
+        }
+    }
+}
+
+pub(crate) fn has_initialized_connections(
+    connections: &HashMap<ConnectionId, ConnectionState>,
+) -> bool {
+    connections
+        
.values() + .any(|connection| connection.session.initialized) +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn app_server_transport_parses_stdio_listen_url() { + let transport = AppServerTransport::from_listen_url(AppServerTransport::DEFAULT_LISTEN_URL) + .expect("stdio listen URL should parse"); + assert_eq!(transport, AppServerTransport::Stdio); + } + + #[test] + fn app_server_transport_parses_websocket_listen_url() { + let transport = AppServerTransport::from_listen_url("ws://127.0.0.1:1234") + .expect("websocket listen URL should parse"); + assert_eq!( + transport, + AppServerTransport::WebSocket { + bind_address: "127.0.0.1:1234".parse().expect("valid socket address"), + } + ); + } + + #[test] + fn app_server_transport_rejects_invalid_websocket_listen_url() { + let err = AppServerTransport::from_listen_url("ws://localhost:1234") + .expect_err("hostname bind address should be rejected"); + assert_eq!( + err.to_string(), + "invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`" + ); + } + + #[test] + fn app_server_transport_rejects_unsupported_listen_url() { + let err = AppServerTransport::from_listen_url("http://127.0.0.1:1234") + .expect_err("unsupported scheme should fail"); + assert_eq!( + err.to_string(), + "unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`" + ); + } +} diff --git a/codex-rs/app-server/tests/common/mcp_process.rs b/codex-rs/app-server/tests/common/mcp_process.rs index 7f77d8fc92a..d7ebc19e954 100644 --- a/codex-rs/app-server/tests/common/mcp_process.rs +++ b/codex-rs/app-server/tests/common/mcp_process.rs @@ -174,7 +174,7 @@ impl McpProcess { client_info, Some(InitializeCapabilities { experimental_api: true, - opt_out_notification_methods: None, + ..Default::default() }), ) .await diff --git a/codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs b/codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs 
new file mode 100644 index 00000000000..ddd4326fc99 --- /dev/null +++ b/codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs @@ -0,0 +1,263 @@ +use anyhow::Context; +use anyhow::Result; +use anyhow::bail; +use app_test_support::create_mock_responses_server_sequence_unchecked; +use codex_app_server_protocol::ClientInfo; +use codex_app_server_protocol::InitializeParams; +use codex_app_server_protocol::JSONRPCError; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCRequest; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use futures::SinkExt; +use futures::StreamExt; +use serde_json::json; +use std::net::SocketAddr; +use std::path::Path; +use std::process::Stdio; +use tempfile::TempDir; +use tokio::io::AsyncBufReadExt; +use tokio::process::Child; +use tokio::process::Command; +use tokio::time::Duration; +use tokio::time::Instant; +use tokio::time::sleep; +use tokio::time::timeout; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::Message as WebSocketMessage; + +const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5); + +type WsClient = WebSocketStream>; + +#[tokio::test] +async fn websocket_transport_routes_per_connection_handshake_and_responses() -> Result<()> { + let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri(), "never")?; + + let bind_addr = reserve_local_addr()?; + let mut process = spawn_websocket_server(codex_home.path(), bind_addr).await?; + + let mut ws1 = connect_websocket(bind_addr).await?; + let mut ws2 = connect_websocket(bind_addr).await?; + + send_initialize_request(&mut ws1, 1, "ws_client_one").await?; + let first_init = read_response_for_id(&mut ws1, 1).await?; + assert_eq!(first_init.id, RequestId::Integer(1)); + + // 
Initialize responses are request-scoped and must not leak to other + // connections. + assert_no_message(&mut ws2, Duration::from_millis(250)).await?; + + send_config_read_request(&mut ws2, 2).await?; + let not_initialized = read_error_for_id(&mut ws2, 2).await?; + assert_eq!(not_initialized.error.message, "Not initialized"); + + send_initialize_request(&mut ws2, 3, "ws_client_two").await?; + let second_init = read_response_for_id(&mut ws2, 3).await?; + assert_eq!(second_init.id, RequestId::Integer(3)); + + // Same request-id on different connections must route independently. + send_config_read_request(&mut ws1, 77).await?; + send_config_read_request(&mut ws2, 77).await?; + let ws1_config = read_response_for_id(&mut ws1, 77).await?; + let ws2_config = read_response_for_id(&mut ws2, 77).await?; + + assert_eq!(ws1_config.id, RequestId::Integer(77)); + assert_eq!(ws2_config.id, RequestId::Integer(77)); + assert!(ws1_config.result.get("config").is_some()); + assert!(ws2_config.result.get("config").is_some()); + + process + .kill() + .await + .context("failed to stop websocket app-server process")?; + Ok(()) +} + +async fn spawn_websocket_server(codex_home: &Path, bind_addr: SocketAddr) -> Result { + let program = codex_utils_cargo_bin::cargo_bin("codex-app-server") + .context("should find app-server binary")?; + let mut cmd = Command::new(program); + cmd.arg("--listen") + .arg(format!("ws://{bind_addr}")) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::piped()) + .env("CODEX_HOME", codex_home) + .env("RUST_LOG", "debug"); + let mut process = cmd + .kill_on_drop(true) + .spawn() + .context("failed to spawn websocket app-server process")?; + + if let Some(stderr) = process.stderr.take() { + let mut stderr_reader = tokio::io::BufReader::new(stderr).lines(); + tokio::spawn(async move { + while let Ok(Some(line)) = stderr_reader.next_line().await { + eprintln!("[websocket app-server stderr] {line}"); + } + }); + } + + Ok(process) +} + +fn 
reserve_local_addr() -> Result { + let listener = std::net::TcpListener::bind("127.0.0.1:0")?; + let addr = listener.local_addr()?; + drop(listener); + Ok(addr) +} + +async fn connect_websocket(bind_addr: SocketAddr) -> Result { + let url = format!("ws://{bind_addr}"); + let deadline = Instant::now() + Duration::from_secs(10); + loop { + match connect_async(&url).await { + Ok((stream, _response)) => return Ok(stream), + Err(err) => { + if Instant::now() >= deadline { + bail!("failed to connect websocket to {url}: {err}"); + } + sleep(Duration::from_millis(50)).await; + } + } + } +} + +async fn send_initialize_request(stream: &mut WsClient, id: i64, client_name: &str) -> Result<()> { + let params = InitializeParams { + client_info: ClientInfo { + name: client_name.to_string(), + title: Some("WebSocket Test Client".to_string()), + version: "0.1.0".to_string(), + }, + capabilities: None, + }; + send_request( + stream, + "initialize", + id, + Some(serde_json::to_value(params)?), + ) + .await +} + +async fn send_config_read_request(stream: &mut WsClient, id: i64) -> Result<()> { + send_request( + stream, + "config/read", + id, + Some(json!({ "includeLayers": false })), + ) + .await +} + +async fn send_request( + stream: &mut WsClient, + method: &str, + id: i64, + params: Option, +) -> Result<()> { + let message = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(id), + method: method.to_string(), + params, + }); + send_jsonrpc(stream, message).await +} + +async fn send_jsonrpc(stream: &mut WsClient, message: JSONRPCMessage) -> Result<()> { + let payload = serde_json::to_string(&message)?; + stream + .send(WebSocketMessage::Text(payload.into())) + .await + .context("failed to send websocket frame") +} + +async fn read_response_for_id(stream: &mut WsClient, id: i64) -> Result { + let target_id = RequestId::Integer(id); + loop { + let message = read_jsonrpc_message(stream).await?; + if let JSONRPCMessage::Response(response) = message + && response.id == 
target_id + { + return Ok(response); + } + } +} + +async fn read_error_for_id(stream: &mut WsClient, id: i64) -> Result { + let target_id = RequestId::Integer(id); + loop { + let message = read_jsonrpc_message(stream).await?; + if let JSONRPCMessage::Error(err) = message + && err.id == target_id + { + return Ok(err); + } + } +} + +async fn read_jsonrpc_message(stream: &mut WsClient) -> Result { + loop { + let frame = timeout(DEFAULT_READ_TIMEOUT, stream.next()) + .await + .context("timed out waiting for websocket frame")? + .context("websocket stream ended unexpectedly")? + .context("failed to read websocket frame")?; + + match frame { + WebSocketMessage::Text(text) => return Ok(serde_json::from_str(text.as_ref())?), + WebSocketMessage::Ping(payload) => { + stream.send(WebSocketMessage::Pong(payload)).await?; + } + WebSocketMessage::Pong(_) => {} + WebSocketMessage::Close(frame) => { + bail!("websocket closed unexpectedly: {frame:?}") + } + WebSocketMessage::Binary(_) => bail!("unexpected binary websocket frame"), + WebSocketMessage::Frame(_) => {} + } + } +} + +async fn assert_no_message(stream: &mut WsClient, wait_for: Duration) -> Result<()> { + match timeout(wait_for, stream.next()).await { + Ok(Some(Ok(frame))) => bail!("unexpected frame while waiting for silence: {frame:?}"), + Ok(Some(Err(err))) => bail!("unexpected websocket read error: {err}"), + Ok(None) => bail!("websocket closed unexpectedly while waiting for silence"), + Err(_) => Ok(()), + } +} + +fn create_config_toml( + codex_home: &Path, + server_uri: &str, + approval_policy: &str, +) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "{approval_policy}" +sandbox_mode = "read-only" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "responses" +request_max_retries = 0 +stream_max_retries = 
0 +"# + ), + ) +} diff --git a/codex-rs/app-server/tests/suite/v2/mod.rs b/codex-rs/app-server/tests/suite/v2/mod.rs index e9e19395be4..48622acddbe 100644 --- a/codex-rs/app-server/tests/suite/v2/mod.rs +++ b/codex-rs/app-server/tests/suite/v2/mod.rs @@ -4,6 +4,7 @@ mod app_list; mod collaboration_mode_list; mod compaction; mod config_rpc; +mod connection_handling_websocket; mod dynamic_tools; mod experimental_api; mod experimental_feature_list; diff --git a/codex-rs/app-server/tests/suite/v2/review.rs b/codex-rs/app-server/tests/suite/v2/review.rs index 1250c055ddf..2f919eee381 100644 --- a/codex-rs/app-server/tests/suite/v2/review.rs +++ b/codex-rs/app-server/tests/suite/v2/review.rs @@ -5,8 +5,6 @@ use app_test_support::create_mock_responses_server_repeating_assistant; use app_test_support::create_mock_responses_server_sequence; use app_test_support::create_shell_command_sse_response; use app_test_support::to_response; -use codex_app_server_protocol::CommandExecutionApprovalDecision; -use codex_app_server_protocol::CommandExecutionRequestApprovalResponse; use codex_app_server_protocol::ItemCompletedNotification; use codex_app_server_protocol::ItemStartedNotification; use codex_app_server_protocol::JSONRPCError; @@ -210,9 +208,7 @@ async fn review_start_exec_approval_item_id_matches_command_execution_item() -> mcp.send_response( request_id, - serde_json::to_value(CommandExecutionRequestApprovalResponse { - decision: CommandExecutionApprovalDecision::Accept, - })?, + serde_json::json!({ "decision": codex_core::protocol::ReviewDecision::Approved }), ) .await?; timeout( diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index defc063eb6d..1cc4fcdaa7b 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -306,6 +306,15 @@ struct AppServerCommand { #[command(subcommand)] subcommand: Option, + /// Transport endpoint URL. Supported values: `stdio://` (default), + /// `ws://IP:PORT`. 
+ #[arg( + long = "listen", + value_name = "URL", + default_value = codex_app_server::AppServerTransport::DEFAULT_LISTEN_URL + )] + listen: codex_app_server::AppServerTransport, + /// Controls whether analytics are enabled by default. /// /// Analytics are disabled by default for app-server. Users have to explicitly opt in @@ -587,11 +596,13 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() } Some(Subcommand::AppServer(app_server_cli)) => match app_server_cli.subcommand { None => { - codex_app_server::run_main( + let transport = app_server_cli.listen; + codex_app_server::run_main_with_transport( codex_linux_sandbox_exe, root_config_overrides, codex_core::config_loader::LoaderOverrides::default(), app_server_cli.analytics_default_enabled, + transport, ) .await?; } @@ -1328,6 +1339,10 @@ mod tests { fn app_server_analytics_default_disabled_without_flag() { let app_server = app_server_from_args(["codex", "app-server"].as_ref()); assert!(!app_server.analytics_default_enabled); + assert_eq!( + app_server.listen, + codex_app_server::AppServerTransport::Stdio + ); } #[test] @@ -1337,6 +1352,36 @@ mod tests { assert!(app_server.analytics_default_enabled); } + #[test] + fn app_server_listen_websocket_url_parses() { + let app_server = app_server_from_args( + ["codex", "app-server", "--listen", "ws://127.0.0.1:4500"].as_ref(), + ); + assert_eq!( + app_server.listen, + codex_app_server::AppServerTransport::WebSocket { + bind_address: "127.0.0.1:4500".parse().expect("valid socket address"), + } + ); + } + + #[test] + fn app_server_listen_stdio_url_parses() { + let app_server = + app_server_from_args(["codex", "app-server", "--listen", "stdio://"].as_ref()); + assert_eq!( + app_server.listen, + codex_app_server::AppServerTransport::Stdio + ); + } + + #[test] + fn app_server_listen_invalid_url_fails_to_parse() { + let parse_result = + MultitoolCli::try_parse_from(["codex", "app-server", "--listen", "http://foo"]); + assert!(parse_result.is_err()); + } + 
#[test] fn features_enable_parses_feature_name() { let cli = MultitoolCli::try_parse_from(["codex", "features", "enable", "unified_exec"]) diff --git a/codex-rs/core/tests/suite/review.rs b/codex-rs/core/tests/suite/review.rs index a7010ecaf1d..1c9c3adf7b0 100644 --- a/codex-rs/core/tests/suite/review.rs +++ b/codex-rs/core/tests/suite/review.rs @@ -371,25 +371,6 @@ async fn review_does_not_emit_agent_message_on_structured_output() { _ => false, }) .await; - // On slower CI hosts, the final AgentMessage can arrive immediately after - // TurnComplete. Drain a brief tail window to make ordering nondeterminism - // harmless while still enforcing "exactly one final AgentMessage". - while let Ok(Ok(event)) = - tokio::time::timeout(std::time::Duration::from_millis(200), codex.next_event()).await - { - match event.msg { - EventMsg::AgentMessage(_) => agent_messages += 1, - EventMsg::EnteredReviewMode(_) => saw_entered = true, - EventMsg::ExitedReviewMode(_) => saw_exited = true, - EventMsg::AgentMessageContentDelta(_) => { - panic!("unexpected AgentMessageContentDelta surfaced during review") - } - EventMsg::AgentMessageDelta(_) => { - panic!("unexpected AgentMessageDelta surfaced during review") - } - _ => {} - } - } assert_eq!(1, agent_messages, "expected exactly one AgentMessage event"); assert!(saw_entered && saw_exited, "missing review lifecycle events"); From b254b72cb4aea8c098376b3b76b9111efca8fe66 Mon Sep 17 00:00:00 2001 From: Max Johnson Date: Tue, 10 Feb 2026 14:29:35 -0800 Subject: [PATCH 08/16] ws fix --- codex-rs/app-server/src/transport.rs | 112 ++++++++++++++++++++------- 1 file changed, 82 insertions(+), 30 deletions(-) diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index 39fd13212cf..93d312d3eff 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -279,46 +279,98 @@ async fn run_websocket_connection( } let (mut websocket_writer, mut websocket_reader) = 
websocket_stream.split(); - loop { - tokio::select! { - outgoing_message = writer_rx.recv() => { - let Some(outgoing_message) = outgoing_message else { - break; - }; - let Some(json) = serialize_outgoing_message(outgoing_message) else { - continue; - }; - if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() { - break; - } + let (websocket_outgoing_tx, mut websocket_outgoing_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + + let mut outgoing_handle = tokio::spawn(async move { + while let Some(outgoing_message) = websocket_outgoing_rx.recv().await { + if websocket_writer.send(outgoing_message).await.is_err() { + break; } - incoming_message = websocket_reader.next() => { - match incoming_message { - Some(Ok(WebSocketMessage::Text(text))) => { - if !forward_incoming_message(&transport_event_tx, connection_id, &text).await { - break; - } + } + }); + + let websocket_outgoing_tx_for_forwarder = websocket_outgoing_tx.clone(); + let mut outgoing_forwarder_handle = tokio::spawn(async move { + while let Some(outgoing_message) = writer_rx.recv().await { + let Some(json) = serialize_outgoing_message(outgoing_message) else { + continue; + }; + if websocket_outgoing_tx_for_forwarder + .send(WebSocketMessage::Text(json.into())) + .await + .is_err() + { + break; + } + } + }); + + let transport_event_tx_for_incoming = transport_event_tx.clone(); + let mut incoming_handle = tokio::spawn(async move { + loop { + match websocket_reader.next().await { + Some(Ok(WebSocketMessage::Text(text))) => { + if !forward_incoming_message( + &transport_event_tx_for_incoming, + connection_id, + &text, + ) + .await + { + break; } - Some(Ok(WebSocketMessage::Ping(payload))) => { - if websocket_writer.send(WebSocketMessage::Pong(payload)).await.is_err() { + } + Some(Ok(WebSocketMessage::Ping(payload))) => { + match websocket_outgoing_tx.try_send(WebSocketMessage::Pong(payload)) { + Ok(()) => {} + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + debug!("dropping websocket 
pong because outgoing queue is full"); + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { break; } } - Some(Ok(WebSocketMessage::Pong(_))) => {} - Some(Ok(WebSocketMessage::Close(_))) | None => break, - Some(Ok(WebSocketMessage::Binary(_))) => { - warn!("dropping unsupported binary websocket message"); - } - Some(Ok(WebSocketMessage::Frame(_))) => {} - Some(Err(err)) => { - warn!("websocket receive error: {err}"); - break; - } } + Some(Ok(WebSocketMessage::Pong(_))) => {} + Some(Ok(WebSocketMessage::Close(_))) | None => break, + Some(Ok(WebSocketMessage::Binary(_))) => { + warn!("dropping unsupported binary websocket message"); + } + Some(Ok(WebSocketMessage::Frame(_))) => {} + Some(Err(err)) => { + warn!("websocket receive error: {err}"); + break; + } + } + } + }); + + tokio::select! { + join_result = &mut outgoing_handle => { + if let Err(err) = join_result { + warn!("websocket outgoing task failed: {err}"); + } + } + join_result = &mut outgoing_forwarder_handle => { + if let Err(err) = join_result { + warn!("websocket outgoing forwarder task failed: {err}"); + } + } + join_result = &mut incoming_handle => { + if let Err(err) = join_result { + warn!("websocket incoming task failed: {err}"); } } } + outgoing_handle.abort(); + outgoing_forwarder_handle.abort(); + incoming_handle.abort(); + let _ = outgoing_handle.await; + let _ = outgoing_forwarder_handle.await; + let _ = incoming_handle.await; + let _ = transport_event_tx .send(TransportEvent::ConnectionClosed { connection_id }) .await; From 7560a70d4e3c002c7098591249d3b9e3c1d53636 Mon Sep 17 00:00:00 2001 From: Max Johnson Date: Tue, 10 Feb 2026 14:38:34 -0800 Subject: [PATCH 09/16] app-server: split incoming and outgoing loops --- codex-rs/app-server/src/lib.rs | 190 ++++++++++++++++++++++----------- 1 file changed, 129 insertions(+), 61 deletions(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index ad049ad3055..4833eda92be 100644 --- 
a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -14,6 +14,7 @@ use std::io::Result as IoResult; use std::path::PathBuf; use std::sync::Arc; +use crate::message_processor::ConnectionSessionState; use crate::message_processor::MessageProcessor; use crate::message_processor::MessageProcessorArgs; use crate::outgoing_message::ConnectionId; @@ -36,6 +37,7 @@ use codex_core::check_execpolicy_for_warnings; use codex_core::config_loader::ConfigLoadError; use codex_core::config_loader::TextRange as CoreTextRange; use codex_feedback::CodexFeedback; +use tokio::sync::Mutex; use tokio::sync::mpsc; use tokio::task::JoinHandle; use toml::Value as TomlValue; @@ -336,79 +338,143 @@ pub async fn run_main_with_transport( } } - let processor_handle = tokio::spawn({ - let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx)); - let cli_overrides: Vec<(String, TomlValue)> = cli_kv_overrides.clone(); - let loader_overrides = loader_overrides_for_config_api; - let mut processor = MessageProcessor::new(MessageProcessorArgs { - outgoing: outgoing_message_sender, - codex_linux_sandbox_exe, - config: Arc::new(config), - cli_overrides, - loader_overrides, - cloud_requirements: cloud_requirements.clone(), - feedback: feedback.clone(), - config_warnings, - }); - let mut thread_created_rx = processor.thread_created_receiver(); - let mut connections = HashMap::::new(); + let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx)); + let cli_overrides: Vec<(String, TomlValue)> = cli_kv_overrides.clone(); + let loader_overrides = loader_overrides_for_config_api; + let processor = MessageProcessor::new(MessageProcessorArgs { + outgoing: outgoing_message_sender, + codex_linux_sandbox_exe, + config: Arc::new(config), + cli_overrides, + loader_overrides, + cloud_requirements: cloud_requirements.clone(), + feedback: feedback.clone(), + config_warnings, + }); + let mut thread_created_rx = processor.thread_created_receiver(); + let 
processor = Arc::new(Mutex::new(processor)); + let connections = Arc::new(Mutex::new(HashMap::::new())); + + let incoming_handle = tokio::spawn({ + let processor = Arc::clone(&processor); + let connections = Arc::clone(&connections); async move { - let mut listen_for_threads = true; - loop { - tokio::select! { - event = transport_event_rx.recv() => { - let Some(event) = event else { - break; + let mut sessions = HashMap::::new(); + while let Some(event) = transport_event_rx.recv().await { + match event { + TransportEvent::ConnectionOpened { + connection_id, + writer, + } => { + sessions.insert(connection_id, ConnectionSessionState::default()); + connections + .lock() + .await + .insert(connection_id, ConnectionState::new(writer)); + } + TransportEvent::ConnectionClosed { connection_id } => { + sessions.remove(&connection_id); + let should_shutdown = { + let mut connections = connections.lock().await; + connections.remove(&connection_id); + shutdown_when_no_connections && connections.is_empty() }; - match event { - TransportEvent::ConnectionOpened { connection_id, writer } => { - connections.insert(connection_id, ConnectionState::new(writer)); - } - TransportEvent::ConnectionClosed { connection_id } => { - connections.remove(&connection_id); - if shutdown_when_no_connections && connections.is_empty() { - break; - } + if should_shutdown { + break; + } + } + TransportEvent::IncomingMessage { + connection_id, + message, + } => match message { + JSONRPCMessage::Request(request) => { + let Some(session) = sessions.get_mut(&connection_id) else { + warn!( + "dropping request from unknown connection: {:?}", + connection_id + ); + continue; + }; + let was_initialized = session.initialized; + let pre_synced_initialize = + !session.initialized && request.method == "initialize"; + if pre_synced_initialize + && let Some(connection_state) = + connections.lock().await.get_mut(&connection_id) + { + connection_state.session.initialized = true; } - TransportEvent::IncomingMessage { 
connection_id, message } => { - match message { - JSONRPCMessage::Request(request) => { - let Some(connection_state) = connections.get_mut(&connection_id) else { - warn!("dropping request from unknown connection: {:?}", connection_id); - continue; - }; - processor - .process_request( - connection_id, - request, - &mut connection_state.session, - ) - .await; - } - JSONRPCMessage::Response(response) => { - processor.process_response(response).await; - } - JSONRPCMessage::Notification(notification) => { - processor.process_notification(notification).await; - } - JSONRPCMessage::Error(err) => { - processor.process_error(err).await; - } - } + + processor + .lock() + .await + .process_request(connection_id, request, session) + .await; + + if pre_synced_initialize + && !session.initialized + && let Some(connection_state) = + connections.lock().await.get_mut(&connection_id) + { + connection_state.session.initialized = false; + } else if session.initialized != was_initialized + && let Some(connection_state) = + connections.lock().await.get_mut(&connection_id) + { + connection_state.session.initialized = session.initialized; } } - } + JSONRPCMessage::Response(response) => { + processor.lock().await.process_response(response).await; + } + JSONRPCMessage::Notification(notification) => { + processor + .lock() + .await + .process_notification(notification) + .await; + } + JSONRPCMessage::Error(err) => { + processor.lock().await.process_error(err).await; + } + }, + } + } + + info!("incoming processor task exited (channel closed)"); + } + }); + + let outgoing_handle = tokio::spawn({ + let processor = Arc::clone(&processor); + let connections = Arc::clone(&connections); + async move { + let mut listen_for_threads = true; + loop { + tokio::select! 
{ envelope = outgoing_rx.recv() => { let Some(envelope) = envelope else { break; }; + let mut connections = connections.lock().await; route_outgoing_envelope(&mut connections, envelope).await; } created = thread_created_rx.recv(), if listen_for_threads => { match created { Ok(thread_id) => { - if has_initialized_connections(&connections) { - processor.try_attach_thread_listener(thread_id).await; + let should_attach = { + let connections = connections.lock().await; + has_initialized_connections(&connections) + }; + if should_attach { + let processor = Arc::clone(&processor); + tokio::spawn(async move { + processor + .lock() + .await + .try_attach_thread_listener(thread_id) + .await; + }); } } Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { @@ -426,13 +492,15 @@ pub async fn run_main_with_transport( } } - info!("processor task exited (channel closed)"); + info!("outgoing router task exited (channel closed)"); } }); drop(transport_event_tx); - let _ = processor_handle.await; + let _ = incoming_handle.await; + outgoing_handle.abort(); + let _ = outgoing_handle.await; if let Some(handle) = websocket_accept_handle { handle.abort(); From 9b572220597ba23bafb75fbb35a2abacc546a2e2 Mon Sep 17 00:00:00 2001 From: Max Johnson Date: Tue, 10 Feb 2026 15:35:16 -0800 Subject: [PATCH 10/16] fixes for opt-out --- codex-rs/app-server/src/lib.rs | 16 ++---- codex-rs/app-server/src/message_processor.rs | 9 ++-- codex-rs/app-server/src/outgoing_message.rs | 24 --------- codex-rs/app-server/src/transport.rs | 54 +++++++++++++++++--- 4 files changed, 57 insertions(+), 46 deletions(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 4833eda92be..2c824e74162 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -395,7 +395,6 @@ pub async fn run_main_with_transport( ); continue; }; - let was_initialized = session.initialized; let pre_synced_initialize = !session.initialized && request.method == "initialize"; if 
pre_synced_initialize @@ -411,17 +410,10 @@ pub async fn run_main_with_transport( .process_request(connection_id, request, session) .await; - if pre_synced_initialize - && !session.initialized - && let Some(connection_state) = - connections.lock().await.get_mut(&connection_id) - { - connection_state.session.initialized = false; - } else if session.initialized != was_initialized - && let Some(connection_state) = - connections.lock().await.get_mut(&connection_id) + if let Some(connection_state) = + connections.lock().await.get_mut(&connection_id) { - connection_state.session.initialized = session.initialized; + connection_state.session = session.clone(); } } JSONRPCMessage::Response(response) => { @@ -457,7 +449,7 @@ pub async fn run_main_with_transport( break; }; let mut connections = connections.lock().await; - route_outgoing_envelope(&mut connections, envelope).await; + route_outgoing_envelope(&mut connections, envelope); } created = thread_created_rx.recv(), if listen_for_threads => { match created { diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 2d8e18a46ef..227f4d7d220 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; use std::sync::RwLock; @@ -114,10 +115,11 @@ pub(crate) struct MessageProcessor { config_warnings: Arc>, } -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default)] pub(crate) struct ConnectionSessionState { pub(crate) initialized: bool, experimental_api_enabled: bool, + pub(crate) opted_out_notification_methods: HashSet, } pub(crate) struct MessageProcessorArgs { @@ -256,9 +258,8 @@ impl MessageProcessor { None => (false, Vec::new()), }; session.experimental_api_enabled = experimental_api_enabled; - self.outgoing - .set_opted_out_notification_methods(opt_out_notification_methods) - .await; + session.opted_out_notification_methods = 
+ opt_out_notification_methods.into_iter().collect(); let ClientInfo { name, title: _title, diff --git a/codex-rs/app-server/src/outgoing_message.rs b/codex-rs/app-server/src/outgoing_message.rs index d58ed8543a3..a5219dc2dc8 100644 --- a/codex-rs/app-server/src/outgoing_message.rs +++ b/codex-rs/app-server/src/outgoing_message.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::collections::HashSet; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -47,7 +46,6 @@ pub(crate) struct OutgoingMessageSender { next_server_request_id: AtomicI64, sender: mpsc::Sender, request_id_to_callback: Mutex>>, - opted_out_notification_methods: Mutex>, } impl OutgoingMessageSender { @@ -56,21 +54,9 @@ impl OutgoingMessageSender { next_server_request_id: AtomicI64::new(0), sender, request_id_to_callback: Mutex::new(HashMap::new()), - opted_out_notification_methods: Mutex::new(HashSet::new()), } } - pub(crate) async fn set_opted_out_notification_methods(&self, methods: Vec) { - let mut opted_out = self.opted_out_notification_methods.lock().await; - opted_out.clear(); - opted_out.extend(methods); - } - - async fn should_skip_notification(&self, method: &str) -> bool { - let opted_out = self.opted_out_notification_methods.lock().await; - opted_out.contains(method) - } - pub(crate) async fn send_request( &self, request: ServerRequestPayload, @@ -186,10 +172,6 @@ impl OutgoingMessageSender { } pub(crate) async fn send_server_notification(&self, notification: ServerNotification) { - let method = notification.to_string(); - if self.should_skip_notification(&method).await { - return; - } if let Err(err) = self .sender .send(OutgoingEnvelope::Broadcast { @@ -204,12 +186,6 @@ impl OutgoingMessageSender { /// All notifications should be migrated to [`ServerNotification`] and /// [`OutgoingMessage::Notification`] should be removed. 
pub(crate) async fn send_notification(&self, notification: OutgoingNotification) { - if self - .should_skip_notification(notification.method.as_str()) - .await - { - return; - } let outgoing_message = OutgoingMessage::Notification(notification); if let Err(err) = self .sender diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index 93d312d3eff..7eecf3a05f8 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -413,7 +413,27 @@ fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option bool { + match message { + OutgoingMessage::AppServerNotification(notification) => { + let method = notification.to_string(); + connection_state + .session + .opted_out_notification_methods + .contains(method.as_str()) + } + OutgoingMessage::Notification(notification) => connection_state + .session + .opted_out_notification_methods + .contains(notification.method.as_str()), + _ => false, + } +} + +pub(crate) fn route_outgoing_envelope( connections: &mut HashMap, envelope: OutgoingEnvelope, ) { @@ -429,15 +449,27 @@ pub(crate) async fn route_outgoing_envelope( ); return; }; - if connection_state.writer.send(message).await.is_err() { - connections.remove(&connection_id); + match connection_state.writer.try_send(message) { + Ok(()) => {} + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + connections.remove(&connection_id); + } + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + warn!( + "dropping slow connection with full outgoing queue: {:?}", + connection_id + ); + connections.remove(&connection_id); + } } } OutgoingEnvelope::Broadcast { message } => { let target_connections: Vec = connections .iter() .filter_map(|(connection_id, connection_state)| { - if connection_state.session.initialized { + if connection_state.session.initialized + && !should_skip_notification_for_connection(connection_state, &message) + { Some(*connection_id) } else { None @@ -449,8 +481,18 @@ 
pub(crate) async fn route_outgoing_envelope( let Some(connection_state) = connections.get(&connection_id) else { continue; }; - if connection_state.writer.send(message.clone()).await.is_err() { - connections.remove(&connection_id); + match connection_state.writer.try_send(message.clone()) { + Ok(()) => {} + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + connections.remove(&connection_id); + } + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + warn!( + "dropping slow connection with full outgoing queue: {:?}", + connection_id + ); + connections.remove(&connection_id); + } } } } From a19232bee9fbf43dbad98b40b2a2daae1e9e5ad9 Mon Sep 17 00:00:00 2001 From: Max Johnson Date: Tue, 10 Feb 2026 15:37:57 -0800 Subject: [PATCH 11/16] blocking send --- codex-rs/app-server/src/lib.rs | 2 +- codex-rs/app-server/src/transport.rs | 30 +++++----------------------- 2 files changed, 6 insertions(+), 26 deletions(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 2c824e74162..ecbf04905b7 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -449,7 +449,7 @@ pub async fn run_main_with_transport( break; }; let mut connections = connections.lock().await; - route_outgoing_envelope(&mut connections, envelope); + route_outgoing_envelope(&mut connections, envelope).await; } created = thread_created_rx.recv(), if listen_for_threads => { match created { diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index 7eecf3a05f8..4fa636acb28 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -433,7 +433,7 @@ fn should_skip_notification_for_connection( } } -pub(crate) fn route_outgoing_envelope( +pub(crate) async fn route_outgoing_envelope( connections: &mut HashMap, envelope: OutgoingEnvelope, ) { @@ -449,18 +449,8 @@ pub(crate) fn route_outgoing_envelope( ); return; }; - match connection_state.writer.try_send(message) { - Ok(()) 
=> {} - Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { - connections.remove(&connection_id); - } - Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { - warn!( - "dropping slow connection with full outgoing queue: {:?}", - connection_id - ); - connections.remove(&connection_id); - } + if connection_state.writer.send(message).await.is_err() { + connections.remove(&connection_id); } } OutgoingEnvelope::Broadcast { message } => { @@ -481,18 +471,8 @@ pub(crate) fn route_outgoing_envelope( let Some(connection_state) = connections.get(&connection_id) else { continue; }; - match connection_state.writer.try_send(message.clone()) { - Ok(()) => {} - Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { - connections.remove(&connection_id); - } - Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { - warn!( - "dropping slow connection with full outgoing queue: {:?}", - connection_id - ); - connections.remove(&connection_id); - } + if connection_state.writer.send(message.clone()).await.is_err() { + connections.remove(&connection_id); } } } From 8eff270c9a2c3d5e74d5215f8eecef2398727a3a Mon Sep 17 00:00:00 2001 From: Max Johnson Date: Wed, 11 Feb 2026 09:31:03 -0800 Subject: [PATCH 12/16] Revert "app-server: split incoming and outgoing loops" This reverts commit 7560a70d4e3c002c7098591249d3b9e3c1d53636. 
--- codex-rs/app-server/src/lib.rs | 182 +++++++++++---------------------- 1 file changed, 61 insertions(+), 121 deletions(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index ecbf04905b7..ad049ad3055 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -14,7 +14,6 @@ use std::io::Result as IoResult; use std::path::PathBuf; use std::sync::Arc; -use crate::message_processor::ConnectionSessionState; use crate::message_processor::MessageProcessor; use crate::message_processor::MessageProcessorArgs; use crate::outgoing_message::ConnectionId; @@ -37,7 +36,6 @@ use codex_core::check_execpolicy_for_warnings; use codex_core::config_loader::ConfigLoadError; use codex_core::config_loader::TextRange as CoreTextRange; use codex_feedback::CodexFeedback; -use tokio::sync::Mutex; use tokio::sync::mpsc; use tokio::task::JoinHandle; use toml::Value as TomlValue; @@ -338,135 +336,79 @@ pub async fn run_main_with_transport( } } - let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx)); - let cli_overrides: Vec<(String, TomlValue)> = cli_kv_overrides.clone(); - let loader_overrides = loader_overrides_for_config_api; - let processor = MessageProcessor::new(MessageProcessorArgs { - outgoing: outgoing_message_sender, - codex_linux_sandbox_exe, - config: Arc::new(config), - cli_overrides, - loader_overrides, - cloud_requirements: cloud_requirements.clone(), - feedback: feedback.clone(), - config_warnings, - }); - let mut thread_created_rx = processor.thread_created_receiver(); - let processor = Arc::new(Mutex::new(processor)); - let connections = Arc::new(Mutex::new(HashMap::::new())); - - let incoming_handle = tokio::spawn({ - let processor = Arc::clone(&processor); - let connections = Arc::clone(&connections); + let processor_handle = tokio::spawn({ + let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx)); + let cli_overrides: Vec<(String, TomlValue)> = 
cli_kv_overrides.clone(); + let loader_overrides = loader_overrides_for_config_api; + let mut processor = MessageProcessor::new(MessageProcessorArgs { + outgoing: outgoing_message_sender, + codex_linux_sandbox_exe, + config: Arc::new(config), + cli_overrides, + loader_overrides, + cloud_requirements: cloud_requirements.clone(), + feedback: feedback.clone(), + config_warnings, + }); + let mut thread_created_rx = processor.thread_created_receiver(); + let mut connections = HashMap::::new(); async move { - let mut sessions = HashMap::::new(); - while let Some(event) = transport_event_rx.recv().await { - match event { - TransportEvent::ConnectionOpened { - connection_id, - writer, - } => { - sessions.insert(connection_id, ConnectionSessionState::default()); - connections - .lock() - .await - .insert(connection_id, ConnectionState::new(writer)); - } - TransportEvent::ConnectionClosed { connection_id } => { - sessions.remove(&connection_id); - let should_shutdown = { - let mut connections = connections.lock().await; - connections.remove(&connection_id); - shutdown_when_no_connections && connections.is_empty() - }; - if should_shutdown { + let mut listen_for_threads = true; + loop { + tokio::select! 
{ + event = transport_event_rx.recv() => { + let Some(event) = event else { break; - } - } - TransportEvent::IncomingMessage { - connection_id, - message, - } => match message { - JSONRPCMessage::Request(request) => { - let Some(session) = sessions.get_mut(&connection_id) else { - warn!( - "dropping request from unknown connection: {:?}", - connection_id - ); - continue; - }; - let pre_synced_initialize = - !session.initialized && request.method == "initialize"; - if pre_synced_initialize - && let Some(connection_state) = - connections.lock().await.get_mut(&connection_id) - { - connection_state.session.initialized = true; + }; + match event { + TransportEvent::ConnectionOpened { connection_id, writer } => { + connections.insert(connection_id, ConnectionState::new(writer)); } - - processor - .lock() - .await - .process_request(connection_id, request, session) - .await; - - if let Some(connection_state) = - connections.lock().await.get_mut(&connection_id) - { - connection_state.session = session.clone(); + TransportEvent::ConnectionClosed { connection_id } => { + connections.remove(&connection_id); + if shutdown_when_no_connections && connections.is_empty() { + break; + } + } + TransportEvent::IncomingMessage { connection_id, message } => { + match message { + JSONRPCMessage::Request(request) => { + let Some(connection_state) = connections.get_mut(&connection_id) else { + warn!("dropping request from unknown connection: {:?}", connection_id); + continue; + }; + processor + .process_request( + connection_id, + request, + &mut connection_state.session, + ) + .await; + } + JSONRPCMessage::Response(response) => { + processor.process_response(response).await; + } + JSONRPCMessage::Notification(notification) => { + processor.process_notification(notification).await; + } + JSONRPCMessage::Error(err) => { + processor.process_error(err).await; + } + } } } - JSONRPCMessage::Response(response) => { - processor.lock().await.process_response(response).await; - } - 
JSONRPCMessage::Notification(notification) => { - processor - .lock() - .await - .process_notification(notification) - .await; - } - JSONRPCMessage::Error(err) => { - processor.lock().await.process_error(err).await; - } - }, - } - } - - info!("incoming processor task exited (channel closed)"); - } - }); - - let outgoing_handle = tokio::spawn({ - let processor = Arc::clone(&processor); - let connections = Arc::clone(&connections); - async move { - let mut listen_for_threads = true; - loop { - tokio::select! { + } envelope = outgoing_rx.recv() => { let Some(envelope) = envelope else { break; }; - let mut connections = connections.lock().await; route_outgoing_envelope(&mut connections, envelope).await; } created = thread_created_rx.recv(), if listen_for_threads => { match created { Ok(thread_id) => { - let should_attach = { - let connections = connections.lock().await; - has_initialized_connections(&connections) - }; - if should_attach { - let processor = Arc::clone(&processor); - tokio::spawn(async move { - processor - .lock() - .await - .try_attach_thread_listener(thread_id) - .await; - }); + if has_initialized_connections(&connections) { + processor.try_attach_thread_listener(thread_id).await; } } Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { @@ -484,15 +426,13 @@ pub async fn run_main_with_transport( } } - info!("outgoing router task exited (channel closed)"); + info!("processor task exited (channel closed)"); } }); drop(transport_event_tx); - let _ = incoming_handle.await; - outgoing_handle.abort(); - let _ = outgoing_handle.await; + let _ = processor_handle.await; if let Some(handle) = websocket_accept_handle { handle.abort(); From 6cc79c4c0e8565604c8ee941176610123caf6618 Mon Sep 17 00:00:00 2001 From: Max Johnson Date: Wed, 11 Feb 2026 09:31:14 -0800 Subject: [PATCH 13/16] Revert "ws fix" This reverts commit b254b72cb4aea8c098376b3b76b9111efca8fe66. 
--- codex-rs/app-server/src/transport.rs | 112 +++++++-------------------- 1 file changed, 30 insertions(+), 82 deletions(-) diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index 4fa636acb28..1dcf5442bf0 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -279,98 +279,46 @@ async fn run_websocket_connection( } let (mut websocket_writer, mut websocket_reader) = websocket_stream.split(); - let (websocket_outgoing_tx, mut websocket_outgoing_rx) = - mpsc::channel::(CHANNEL_CAPACITY); - - let mut outgoing_handle = tokio::spawn(async move { - while let Some(outgoing_message) = websocket_outgoing_rx.recv().await { - if websocket_writer.send(outgoing_message).await.is_err() { - break; - } - } - }); - - let websocket_outgoing_tx_for_forwarder = websocket_outgoing_tx.clone(); - let mut outgoing_forwarder_handle = tokio::spawn(async move { - while let Some(outgoing_message) = writer_rx.recv().await { - let Some(json) = serialize_outgoing_message(outgoing_message) else { - continue; - }; - if websocket_outgoing_tx_for_forwarder - .send(WebSocketMessage::Text(json.into())) - .await - .is_err() - { - break; - } - } - }); - - let transport_event_tx_for_incoming = transport_event_tx.clone(); - let mut incoming_handle = tokio::spawn(async move { - loop { - match websocket_reader.next().await { - Some(Ok(WebSocketMessage::Text(text))) => { - if !forward_incoming_message( - &transport_event_tx_for_incoming, - connection_id, - &text, - ) - .await - { - break; - } + loop { + tokio::select! 
{ + outgoing_message = writer_rx.recv() => { + let Some(outgoing_message) = outgoing_message else { + break; + }; + let Some(json) = serialize_outgoing_message(outgoing_message) else { + continue; + }; + if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() { + break; } - Some(Ok(WebSocketMessage::Ping(payload))) => { - match websocket_outgoing_tx.try_send(WebSocketMessage::Pong(payload)) { - Ok(()) => {} - Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { - debug!("dropping websocket pong because outgoing queue is full"); + } + incoming_message = websocket_reader.next() => { + match incoming_message { + Some(Ok(WebSocketMessage::Text(text))) => { + if !forward_incoming_message(&transport_event_tx, connection_id, &text).await { + break; } - Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + } + Some(Ok(WebSocketMessage::Ping(payload))) => { + if websocket_writer.send(WebSocketMessage::Pong(payload)).await.is_err() { break; } } + Some(Ok(WebSocketMessage::Pong(_))) => {} + Some(Ok(WebSocketMessage::Close(_))) | None => break, + Some(Ok(WebSocketMessage::Binary(_))) => { + warn!("dropping unsupported binary websocket message"); + } + Some(Ok(WebSocketMessage::Frame(_))) => {} + Some(Err(err)) => { + warn!("websocket receive error: {err}"); + break; + } } - Some(Ok(WebSocketMessage::Pong(_))) => {} - Some(Ok(WebSocketMessage::Close(_))) | None => break, - Some(Ok(WebSocketMessage::Binary(_))) => { - warn!("dropping unsupported binary websocket message"); - } - Some(Ok(WebSocketMessage::Frame(_))) => {} - Some(Err(err)) => { - warn!("websocket receive error: {err}"); - break; - } - } - } - }); - - tokio::select! 
{ - join_result = &mut outgoing_handle => { - if let Err(err) = join_result { - warn!("websocket outgoing task failed: {err}"); - } - } - join_result = &mut outgoing_forwarder_handle => { - if let Err(err) = join_result { - warn!("websocket outgoing forwarder task failed: {err}"); - } - } - join_result = &mut incoming_handle => { - if let Err(err) = join_result { - warn!("websocket incoming task failed: {err}"); } } } - outgoing_handle.abort(); - outgoing_forwarder_handle.abort(); - incoming_handle.abort(); - let _ = outgoing_handle.await; - let _ = outgoing_forwarder_handle.await; - let _ = incoming_handle.await; - let _ = transport_event_tx .send(TransportEvent::ConnectionClosed { connection_id }) .await; From b3ae8dc8b96b0313301d7065996f4032739dab13 Mon Sep 17 00:00:00 2001 From: Max Johnson Date: Wed, 11 Feb 2026 09:35:38 -0800 Subject: [PATCH 14/16] lint --- codex-rs/app-server/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 14d9c85bd84..dad90954b8f 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -7,9 +7,9 @@ use codex_core::config::ConfigBuilder; use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::ConfigLayerStackOrdering; use codex_core::config_loader::LoaderOverrides; +use codex_utils_cli::CliConfigOverrides; use std::collections::HashMap; use std::collections::VecDeque; -use codex_utils_cli::CliConfigOverrides; use std::io::ErrorKind; use std::io::Result as IoResult; use std::path::PathBuf; From 8a586e8ca3cdafdb6d5deb74a91f65215040966c Mon Sep 17 00:00:00 2001 From: Max Johnson Date: Wed, 11 Feb 2026 09:46:59 -0800 Subject: [PATCH 15/16] reintroduce optout --- codex-rs/app-server/src/lib.rs | 36 ++++++++++++++++- codex-rs/app-server/src/transport.rs | 40 +++++++++++++------ .../app-server/tests/suite/v2/analytics.rs | 5 ++- .../app-server/tests/suite/v2/config_rpc.rs | 40 
+++++++++++++++---- 4 files changed, 98 insertions(+), 23 deletions(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index dad90954b8f..2a31b2053e7 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -9,11 +9,13 @@ use codex_core::config_loader::ConfigLayerStackOrdering; use codex_core::config_loader::LoaderOverrides; use codex_utils_cli::CliConfigOverrides; use std::collections::HashMap; +use std::collections::HashSet; use std::collections::VecDeque; use std::io::ErrorKind; use std::io::Result as IoResult; use std::path::PathBuf; use std::sync::Arc; +use std::sync::RwLock; use std::sync::atomic::AtomicBool; use crate::message_processor::MessageProcessor; @@ -79,6 +81,7 @@ enum OutboundControlEvent { connection_id: ConnectionId, writer: mpsc::Sender, initialized: Arc, + opted_out_notification_methods: Arc>>, }, /// Remove state for a closed/disconnected connection. Closed { connection_id: ConnectionId }, @@ -381,10 +384,15 @@ pub async fn run_main_with_transport( connection_id, writer, initialized, + opted_out_notification_methods, } => { outbound_connections.insert( connection_id, - OutboundConnectionState::new(writer, initialized), + OutboundConnectionState::new( + writer, + initialized, + opted_out_notification_methods, + ), ); } OutboundControlEvent::Closed { connection_id } => { @@ -449,18 +457,29 @@ pub async fn run_main_with_transport( match event { TransportEvent::ConnectionOpened { connection_id, writer } => { let outbound_initialized = Arc::new(AtomicBool::new(false)); + let outbound_opted_out_notification_methods = + Arc::new(RwLock::new(HashSet::new())); if outbound_control_tx .send(OutboundControlEvent::Opened { connection_id, writer, initialized: Arc::clone(&outbound_initialized), + opted_out_notification_methods: Arc::clone( + &outbound_opted_out_notification_methods, + ), }) .await .is_err() { break; } - connections.insert(connection_id, ConnectionState::new(outbound_initialized)); + 
connections.insert( + connection_id, + ConnectionState::new( + outbound_initialized, + outbound_opted_out_notification_methods, + ), + ); } TransportEvent::ConnectionClosed { connection_id } => { if outbound_control_tx @@ -491,6 +510,19 @@ pub async fn run_main_with_transport( &connection_state.outbound_initialized, ) .await; + if let Ok(mut opted_out_notification_methods) = connection_state + .outbound_opted_out_notification_methods + .write() + { + *opted_out_notification_methods = connection_state + .session + .opted_out_notification_methods + .clone(); + } else { + warn!( + "failed to update outbound opted-out notifications" + ); + } if !was_initialized && connection_state.session.initialized { processor.send_initialize_notifications().await; } diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index fe2ba10c38f..cbfd263a555 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -12,11 +12,13 @@ use owo_colors::OwoColorize; use owo_colors::Stream; use owo_colors::Style; use std::collections::HashMap; +use std::collections::HashSet; use std::io::ErrorKind; use std::io::Result as IoResult; use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; +use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; @@ -145,13 +147,18 @@ pub(crate) enum TransportEvent { pub(crate) struct ConnectionState { pub(crate) outbound_initialized: Arc, + pub(crate) outbound_opted_out_notification_methods: Arc>>, pub(crate) session: ConnectionSessionState, } impl ConnectionState { - pub(crate) fn new(outbound_initialized: Arc) -> Self { + pub(crate) fn new( + outbound_initialized: Arc, + outbound_opted_out_notification_methods: Arc>>, + ) -> Self { Self { outbound_initialized, + outbound_opted_out_notification_methods, session: ConnectionSessionState::default(), } } @@ -159,13 +166,19 @@ impl ConnectionState { pub(crate) struct 
OutboundConnectionState { pub(crate) initialized: Arc, + pub(crate) opted_out_notification_methods: Arc>>, pub(crate) writer: mpsc::Sender, } impl OutboundConnectionState { - pub(crate) fn new(writer: mpsc::Sender, initialized: Arc) -> Self { + pub(crate) fn new( + writer: mpsc::Sender, + initialized: Arc, + opted_out_notification_methods: Arc>>, + ) -> Self { Self { initialized, + opted_out_notification_methods, writer, } } @@ -428,21 +441,22 @@ fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option bool { + let Ok(opted_out_notification_methods) = connection_state.opted_out_notification_methods.read() + else { + warn!("failed to read outbound opted-out notifications"); + return false; + }; match message { OutgoingMessage::AppServerNotification(notification) => { let method = notification.to_string(); - connection_state - .session - .opted_out_notification_methods - .contains(method.as_str()) + opted_out_notification_methods.contains(method.as_str()) + } + OutgoingMessage::Notification(notification) => { + opted_out_notification_methods.contains(notification.method.as_str()) } - OutgoingMessage::Notification(notification) => connection_state - .session - .opted_out_notification_methods - .contains(notification.method.as_str()), _ => false, } } @@ -473,7 +487,9 @@ pub(crate) async fn route_outgoing_envelope( let target_connections: Vec = connections .iter() .filter_map(|(connection_id, connection_state)| { - if connection_state.initialized.load(Ordering::Acquire) { + if connection_state.initialized.load(Ordering::Acquire) + && !should_skip_notification_for_connection(connection_state, &message) + { Some(*connection_id) } else { None diff --git a/codex-rs/app-server/tests/suite/v2/analytics.rs b/codex-rs/app-server/tests/suite/v2/analytics.rs index e18a0d3c849..0d05d644658 100644 --- a/codex-rs/app-server/tests/suite/v2/analytics.rs +++ b/codex-rs/app-server/tests/suite/v2/analytics.rs @@ -36,8 +36,9 @@ async fn 
app_server_default_analytics_disabled_without_flag() -> Result<()> { .map_err(|err| anyhow::anyhow!(err.to_string()))?; // With analytics unset in the config and the default flag is false, metrics are disabled. - // No provider is built. - assert_eq!(provider.is_none(), true); + // A provider may still exist for non-metrics telemetry, so check metrics specifically. + let has_metrics = provider.as_ref().and_then(|otel| otel.metrics()).is_some(); + assert_eq!(has_metrics, false); Ok(()) } diff --git a/codex-rs/app-server/tests/suite/v2/config_rpc.rs b/codex-rs/app-server/tests/suite/v2/config_rpc.rs index 4129564b16b..cceadbb3377 100644 --- a/codex-rs/app-server/tests/suite/v2/config_rpc.rs +++ b/codex-rs/app-server/tests/suite/v2/config_rpc.rs @@ -560,9 +560,22 @@ fn assert_layers_user_then_optional_system( layers: &[codex_app_server_protocol::ConfigLayer], user_file: AbsolutePathBuf, ) -> Result<()> { - assert_eq!(layers.len(), 2); - assert_eq!(layers[0].name, ConfigLayerSource::User { file: user_file }); - assert!(matches!(layers[1].name, ConfigLayerSource::System { .. })); + let mut first_index = 0; + if matches!( + layers.first().map(|layer| &layer.name), + Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm { .. }) + ) { + first_index = 1; + } + assert_eq!(layers.len(), first_index + 2); + assert_eq!( + layers[first_index].name, + ConfigLayerSource::User { file: user_file } + ); + assert!(matches!( + layers[first_index + 1].name, + ConfigLayerSource::System { .. } + )); Ok(()) } @@ -571,12 +584,25 @@ fn assert_layers_managed_user_then_optional_system( managed_file: AbsolutePathBuf, user_file: AbsolutePathBuf, ) -> Result<()> { - assert_eq!(layers.len(), 3); + let mut first_index = 0; + if matches!( + layers.first().map(|layer| &layer.name), + Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm { .. 
}) + ) { + first_index = 1; + } + assert_eq!(layers.len(), first_index + 3); assert_eq!( - layers[0].name, + layers[first_index].name, ConfigLayerSource::LegacyManagedConfigTomlFromFile { file: managed_file } ); - assert_eq!(layers[1].name, ConfigLayerSource::User { file: user_file }); - assert!(matches!(layers[2].name, ConfigLayerSource::System { .. })); + assert_eq!( + layers[first_index + 1].name, + ConfigLayerSource::User { file: user_file } + ); + assert!(matches!( + layers[first_index + 2].name, + ConfigLayerSource::System { .. } + )); Ok(()) } From 047f59c6d79b33abfe7c518e4d95db93547632fa Mon Sep 17 00:00:00 2001 From: Max Johnson Date: Wed, 11 Feb 2026 09:53:58 -0800 Subject: [PATCH 16/16] fix --- codex-rs/app-server/tests/suite/v2/config_rpc.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codex-rs/app-server/tests/suite/v2/config_rpc.rs b/codex-rs/app-server/tests/suite/v2/config_rpc.rs index cceadbb3377..de4c51cde47 100644 --- a/codex-rs/app-server/tests/suite/v2/config_rpc.rs +++ b/codex-rs/app-server/tests/suite/v2/config_rpc.rs @@ -563,7 +563,7 @@ fn assert_layers_user_then_optional_system( let mut first_index = 0; if matches!( layers.first().map(|layer| &layer.name), - Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm { .. }) + Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm) ) { first_index = 1; } @@ -587,7 +587,7 @@ fn assert_layers_managed_user_then_optional_system( let mut first_index = 0; if matches!( layers.first().map(|layer| &layer.name), - Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm { .. }) + Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm) ) { first_index = 1; }