diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 66d4a501ec5..c40693c601a 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -28,6 +28,12 @@ Supported transports: Websocket transport is currently experimental and unsupported. Do not rely on it for production workloads. +Backpressure behavior: + +- The server uses bounded queues between transport ingress, request processing, and outbound writes. +- When request ingress is saturated, new requests are rejected with a JSON-RPC error code `-32001` and message `"Server overloaded; retry later."`. +- Clients should treat this as retryable and use exponential backoff with jitter. + ## Message Schema Currently, you can dump a TypeScript version of the schema using `codex app-server generate-ts`, or a JSON Schema bundle via `codex app-server generate-json-schema`. Each output is specific to the version of Codex you used to run the command, so the generated artifacts are guaranteed to match that version. diff --git a/codex-rs/app-server/src/error_code.rs b/codex-rs/app-server/src/error_code.rs index 1ffd889d404..ca93b2f2d33 100644 --- a/codex-rs/app-server/src/error_code.rs +++ b/codex-rs/app-server/src/error_code.rs @@ -1,2 +1,3 @@ pub(crate) const INVALID_REQUEST_ERROR_CODE: i64 = -32600; pub(crate) const INTERNAL_ERROR_CODE: i64 = -32603; +pub(crate) const OVERLOADED_ERROR_CODE: i64 = -32001; diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index ad049ad3055..0b23fed74ed 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -9,10 +9,12 @@ use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::ConfigLayerStackOrdering; use codex_core::config_loader::LoaderOverrides; use std::collections::HashMap; +use std::collections::VecDeque; use std::io::ErrorKind; use std::io::Result as IoResult; use std::path::PathBuf; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use crate::message_processor::MessageProcessor; use crate::message_processor::MessageProcessorArgs; @@ -21,6 +23,7 @@ use crate::outgoing_message::OutgoingEnvelope; use crate::outgoing_message::OutgoingMessageSender; use crate::transport::CHANNEL_CAPACITY; use crate::transport::ConnectionState; +use crate::transport::OutboundConnectionState; use crate::transport::TransportEvent; use crate::transport::has_initialized_connections; use crate::transport::route_outgoing_envelope; @@ -61,6 +64,26 @@ mod transport; pub use crate::transport::AppServerTransport; +/// Control-plane messages from the processor/transport side to the outbound router task. +/// +/// `run_main_with_transport` now uses two loops/tasks: +/// - processor loop: handles incoming JSON-RPC and request dispatch +/// - outbound loop: performs potentially slow writes to per-connection writers +/// +/// `OutboundControlEvent` keeps those loops coordinated without sharing mutable +/// connection state directly. In particular, the outbound loop needs to know +/// when a connection opens/closes so it can route messages correctly. +enum OutboundControlEvent { + /// Register a new writer for an opened connection. + Opened { + connection_id: ConnectionId, + writer: mpsc::Sender, + initialized: Arc, + }, + /// Remove state for a closed/disconnected connection. + Closed { connection_id: ConnectionId }, +} + fn config_warning_from_error( summary: impl Into, err: &std::io::Error, @@ -197,6 +220,8 @@ pub async fn run_main_with_transport( let (transport_event_tx, mut transport_event_rx) = mpsc::channel::(CHANNEL_CAPACITY); let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let (outbound_control_tx, mut outbound_control_rx) = + mpsc::channel::(CHANNEL_CAPACITY); let mut stdio_handles = Vec::>::new(); let mut websocket_accept_handle = None; @@ -336,8 +361,65 @@ pub async fn run_main_with_transport( } } + let transport_event_tx_for_outbound = transport_event_tx.clone(); + let outbound_handle = tokio::spawn(async move { + let mut outbound_connections = HashMap::::new(); + let mut pending_closed_connections = VecDeque::::new(); + loop { + tokio::select! { + biased; + event = outbound_control_rx.recv() => { + let Some(event) = event else { + break; + }; + match event { + OutboundControlEvent::Opened { + connection_id, + writer, + initialized, + } => { + outbound_connections.insert( + connection_id, + OutboundConnectionState::new(writer, initialized), + ); + } + OutboundControlEvent::Closed { connection_id } => { + outbound_connections.remove(&connection_id); + } + } + } + envelope = outgoing_rx.recv() => { + let Some(envelope) = envelope else { + break; + }; + let disconnected_connections = + route_outgoing_envelope(&mut outbound_connections, envelope).await; + pending_closed_connections.extend(disconnected_connections); + } + } + + while let Some(connection_id) = pending_closed_connections.front().copied() { + match transport_event_tx_for_outbound + .try_send(TransportEvent::ConnectionClosed { connection_id }) + { + Ok(()) => { + pending_closed_connections.pop_front(); + } + Err(mpsc::error::TrySendError::Full(_)) => { + break; + } + Err(mpsc::error::TrySendError::Closed(_)) => { + return; + } + } + } + } + info!("outbound router task exited (channel closed)"); + }); + let processor_handle = tokio::spawn({ let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx)); + let outbound_control_tx = outbound_control_tx; let cli_overrides: Vec<(String, TomlValue)> = cli_kv_overrides.clone(); let loader_overrides = loader_overrides_for_config_api; let mut processor = MessageProcessor::new(MessageProcessorArgs { @@ -362,9 +444,28 @@ pub async fn run_main_with_transport( }; match event { TransportEvent::ConnectionOpened { connection_id, writer } => { - connections.insert(connection_id, ConnectionState::new(writer)); + let outbound_initialized = Arc::new(AtomicBool::new(false)); + if outbound_control_tx + .send(OutboundControlEvent::Opened { + connection_id, + writer, + initialized: Arc::clone(&outbound_initialized), + }) + .await + .is_err() + { + break; + } + connections.insert(connection_id, ConnectionState::new(outbound_initialized)); } TransportEvent::ConnectionClosed { connection_id } => { + if outbound_control_tx + .send(OutboundControlEvent::Closed { connection_id }) + .await + .is_err() + { + break; + } connections.remove(&connection_id); if shutdown_when_no_connections && connections.is_empty() { break; @@ -377,13 +478,18 @@ pub async fn run_main_with_transport( warn!("dropping request from unknown connection: {:?}", connection_id); continue; }; + let was_initialized = connection_state.session.initialized; processor .process_request( connection_id, request, &mut connection_state.session, + &connection_state.outbound_initialized, ) .await; + if !was_initialized && connection_state.session.initialized { + processor.send_initialize_notifications().await; + } } JSONRPCMessage::Response(response) => { processor.process_response(response).await; @@ -398,12 +504,6 @@ pub async fn run_main_with_transport( } } } - envelope = outgoing_rx.recv() => { - let Some(envelope) = envelope else { - break; - }; - route_outgoing_envelope(&mut connections, envelope).await; - } created = thread_created_rx.recv(), if listen_for_threads => { match created { Ok(thread_id) => { @@ -433,6 +533,7 @@ pub async fn run_main_with_transport( drop(transport_event_tx); let _ = processor_handle.await; + let _ = outbound_handle.await; if let Some(handle) = websocket_accept_handle { handle.abort(); diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 26da44df311..b962d52902e 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -1,6 +1,8 @@ use std::path::PathBuf; use std::sync::Arc; use std::sync::RwLock; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; use crate::codex_message_processor::CodexMessageProcessor; use crate::codex_message_processor::CodexMessageProcessorArgs; @@ -191,6 +193,7 @@ impl MessageProcessor { connection_id: ConnectionId, request: JSONRPCRequest, session: &mut ConnectionSessionState, + outbound_initialized: &AtomicBool, ) { let request_id = ConnectionRequestId { connection_id, @@ -286,14 +289,7 @@ impl MessageProcessor { self.outgoing.send_response(request_id, response).await; session.initialized = true; - for notification in self.config_warnings.iter().cloned() { - self.outgoing - .send_server_notification(ServerNotification::ConfigWarning( - notification, - )) - .await; - } - + outbound_initialized.store(true, Ordering::Release); return; } } @@ -381,6 +377,14 @@ impl MessageProcessor { self.codex_message_processor.thread_created_receiver() } + pub(crate) async fn send_initialize_notifications(&self) { + for notification in self.config_warnings.iter().cloned() { + self.outgoing + .send_server_notification(ServerNotification::ConfigWarning(notification)) + .await; + } + } + pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) { self.codex_message_processor .try_attach_thread_listener(thread_id) diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index 39fd13212cf..d70eb3ffd8d 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -1,7 +1,10 @@ +use crate::error_code::OVERLOADED_ERROR_CODE; use crate::message_processor::ConnectionSessionState; use crate::outgoing_message::ConnectionId; use crate::outgoing_message::OutgoingEnvelope; +use crate::outgoing_message::OutgoingError; use crate::outgoing_message::OutgoingMessage; +use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::JSONRPCMessage; use futures::SinkExt; use futures::StreamExt; @@ -14,6 +17,7 @@ use std::io::Result as IoResult; use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; use tokio::io::AsyncBufReadExt; @@ -140,25 +144,40 @@ pub(crate) enum TransportEvent { } pub(crate) struct ConnectionState { - pub(crate) writer: mpsc::Sender, + pub(crate) outbound_initialized: Arc, pub(crate) session: ConnectionSessionState, } impl ConnectionState { - pub(crate) fn new(writer: mpsc::Sender) -> Self { + pub(crate) fn new(outbound_initialized: Arc) -> Self { Self { - writer, + outbound_initialized, session: ConnectionSessionState::default(), } } } +pub(crate) struct OutboundConnectionState { + pub(crate) initialized: Arc, + pub(crate) writer: mpsc::Sender, +} + +impl OutboundConnectionState { + pub(crate) fn new(writer: mpsc::Sender, initialized: Arc) -> Self { + Self { + initialized, + writer, + } + } +} + pub(crate) async fn start_stdio_connection( transport_event_tx: mpsc::Sender, stdio_handles: &mut Vec>, ) -> IoResult<()> { let connection_id = ConnectionId(0); let (writer_tx, mut writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let writer_tx_for_reader = writer_tx.clone(); transport_event_tx .send(TransportEvent::ConnectionOpened { connection_id, @@ -178,6 +197,7 @@ pub(crate) async fn start_stdio_connection( Ok(Some(line)) => { if !forward_incoming_message( &transport_event_tx_for_reader, + &writer_tx_for_reader, connection_id, &line, ) @@ -267,6 +287,7 @@ async fn run_websocket_connection( }; let (writer_tx, mut writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let writer_tx_for_reader = writer_tx.clone(); if transport_event_tx .send(TransportEvent::ConnectionOpened { connection_id, @@ -295,7 +316,14 @@ async fn run_websocket_connection( incoming_message = websocket_reader.next() => { match incoming_message { Some(Ok(WebSocketMessage::Text(text))) => { - if !forward_incoming_message(&transport_event_tx, connection_id, &text).await { + if !forward_incoming_message( + &transport_event_tx, + &writer_tx_for_reader, + connection_id, + &text, + ) + .await + { break; } } @@ -326,17 +354,14 @@ async fn run_websocket_connection( async fn forward_incoming_message( transport_event_tx: &mpsc::Sender, + writer: &mpsc::Sender, connection_id: ConnectionId, payload: &str, ) -> bool { match serde_json::from_str::(payload) { - Ok(message) => transport_event_tx - .send(TransportEvent::IncomingMessage { - connection_id, - message, - }) - .await - .is_ok(), + Ok(message) => { + enqueue_incoming_message(transport_event_tx, writer, connection_id, message).await + } Err(err) => { error!("Failed to deserialize JSONRPCMessage: {err}"); true @@ -344,6 +369,47 @@ async fn forward_incoming_message( } } +async fn enqueue_incoming_message( + transport_event_tx: &mpsc::Sender, + writer: &mpsc::Sender, + connection_id: ConnectionId, + message: JSONRPCMessage, +) -> bool { + let event = TransportEvent::IncomingMessage { + connection_id, + message, + }; + match transport_event_tx.try_send(event) { + Ok(()) => true, + Err(mpsc::error::TrySendError::Closed(_)) => false, + Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage { + connection_id, + message: JSONRPCMessage::Request(request), + })) => { + let overload_error = OutgoingMessage::Error(OutgoingError { + id: request.id, + error: JSONRPCErrorError { + code: OVERLOADED_ERROR_CODE, + message: "Server overloaded; retry later.".to_string(), + data: None, + }, + }); + match writer.try_send(overload_error) { + Ok(()) => true, + Err(mpsc::error::TrySendError::Closed(_)) => false, + Err(mpsc::error::TrySendError::Full(_overload_error)) => { + warn!( + "dropping overload response for connection {:?}: outbound queue is full", + connection_id + ); + true + } + } + } + Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(), + } +} + fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option { let value = match serde_json::to_value(outgoing_message) { Ok(value) => value, @@ -362,9 +428,10 @@ fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option, + connections: &mut HashMap, envelope: OutgoingEnvelope, -) { +) -> Vec { + let mut disconnected = Vec::new(); match envelope { OutgoingEnvelope::ToConnection { connection_id, @@ -375,17 +442,18 @@ pub(crate) async fn route_outgoing_envelope( "dropping message for disconnected connection: {:?}", connection_id ); - return; + return disconnected; }; if connection_state.writer.send(message).await.is_err() { connections.remove(&connection_id); + disconnected.push(connection_id); } } OutgoingEnvelope::Broadcast { message } => { let target_connections: Vec = connections .iter() .filter_map(|(connection_id, connection_state)| { - if connection_state.session.initialized { + if connection_state.initialized.load(Ordering::Acquire) { Some(*connection_id) } else { None @@ -399,10 +467,12 @@ pub(crate) async fn route_outgoing_envelope( }; if connection_state.writer.send(message.clone()).await.is_err() { connections.remove(&connection_id); + disconnected.push(connection_id); } } } } + disconnected } pub(crate) fn has_initialized_connections( @@ -416,7 +486,9 @@ pub(crate) fn has_initialized_connections( #[cfg(test)] mod tests { use super::*; + use crate::error_code::OVERLOADED_ERROR_CODE; use pretty_assertions::assert_eq; + use serde_json::json; #[test] fn app_server_transport_parses_stdio_listen_url() { @@ -456,4 +528,186 @@ mod tests { "unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`" ); } + + #[tokio::test] + async fn enqueue_incoming_request_returns_overload_error_when_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + let first_message = + JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: first_message.clone(), + }) + .await + .expect("queue should accept first message"); + + let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(7), + method: "config/read".to_string(), + params: Some(json!({ "includeLayers": false })), + }); + assert!( + enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request).await + ); + + let queued_event = transport_event_rx + .recv() + .await + .expect("first event should stay queued"); + match queued_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message, + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(message, first_message); + } + _ => panic!("expected queued incoming message"), + } + + let overload = writer_rx + .recv() + .await + .expect("request should receive overload error"); + let overload_json = serde_json::to_value(overload).expect("serialize overload error"); + assert_eq!( + overload_json, + json!({ + "id": 7, + "error": { + "code": OVERLOADED_ERROR_CODE, + "message": "Server overloaded; retry later." + } + }) + ); + } + + #[tokio::test] + async fn enqueue_incoming_response_waits_instead_of_dropping_when_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); + let (writer_tx, _writer_rx) = mpsc::channel(1); + + let first_message = + JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: first_message.clone(), + }) + .await + .expect("queue should accept first message"); + + let response = JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { + id: codex_app_server_protocol::RequestId::Integer(7), + result: json!({"ok": true}), + }); + let transport_event_tx_for_enqueue = transport_event_tx.clone(); + let writer_tx_for_enqueue = writer_tx.clone(); + let enqueue_handle = tokio::spawn(async move { + enqueue_incoming_message( + &transport_event_tx_for_enqueue, + &writer_tx_for_enqueue, + connection_id, + response, + ) + .await + }); + + let queued_event = transport_event_rx + .recv() + .await + .expect("first event should be dequeued"); + match queued_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message, + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(message, first_message); + } + _ => panic!("expected queued incoming message"), + } + + let enqueue_result = enqueue_handle.await.expect("enqueue task should not panic"); + assert!(enqueue_result); + + let forwarded_event = transport_event_rx + .recv() + .await + .expect("response should be forwarded instead of dropped"); + match forwarded_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message: + JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { id, result }), + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(id, codex_app_server_protocol::RequestId::Integer(7)); + assert_eq!(result, json!({"ok": true})); + } + _ => panic!("expected forwarded response message"), + } + } + + #[tokio::test] + async fn enqueue_incoming_request_does_not_block_when_writer_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, _transport_event_rx) = mpsc::channel(1); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: JSONRPCMessage::Notification( + codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }, + ), + }) + .await + .expect("transport queue should accept first message"); + + writer_tx + .send(OutgoingMessage::Notification( + crate::outgoing_message::OutgoingNotification { + method: "queued".to_string(), + params: None, + }, + )) + .await + .expect("writer queue should accept first message"); + + let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(7), + method: "config/read".to_string(), + params: Some(json!({ "includeLayers": false })), + }); + + let enqueue_result = tokio::time::timeout( + std::time::Duration::from_millis(100), + enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request), + ) + .await + .expect("enqueue should not block while writer queue is full"); + assert!(enqueue_result); + + let queued_outgoing = writer_rx + .recv() + .await + .expect("writer queue should still contain original message"); + let queued_json = serde_json::to_value(queued_outgoing).expect("serialize queued message"); + assert_eq!(queued_json, json!({ "method": "queued" })); + } }