diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index e0fc1cfa439..3892f88860a 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -4,7 +4,7 @@ use crate::codex_message_processor::read_summary_from_rollout; use crate::codex_message_processor::summary_to_thread; use crate::error_code::INTERNAL_ERROR_CODE; use crate::error_code::INVALID_REQUEST_ERROR_CODE; -use crate::outgoing_message::OutgoingMessageSender; +use crate::outgoing_message::ThreadScopedOutgoingMessageSender; use crate::thread_state::ThreadState; use crate::thread_state::TurnSummary; use codex_app_server_protocol::AccountRateLimitsUpdatedNotification; @@ -107,7 +107,7 @@ pub(crate) async fn apply_bespoke_event_handling( event: Event, conversation_id: ThreadId, conversation: Arc, - outgoing: Arc, + outgoing: ThreadScopedOutgoingMessageSender, thread_state: Arc>, api_version: ApiVersion, fallback_model_provider: String, @@ -850,7 +850,7 @@ pub(crate) async fn apply_bespoke_event_handling( conversation_id, &event_turn_id, raw_response_item_event.item, - outgoing.as_ref(), + &outgoing, ) .await; } @@ -899,7 +899,7 @@ pub(crate) async fn apply_bespoke_event_handling( changes, status, event_turn_id.clone(), - outgoing.as_ref(), + &outgoing, &thread_state, ) .await; @@ -1142,7 +1142,7 @@ pub(crate) async fn apply_bespoke_event_handling( &event_turn_id, turn_diff_event, api_version, - outgoing.as_ref(), + &outgoing, ) .await; } @@ -1152,7 +1152,7 @@ pub(crate) async fn apply_bespoke_event_handling( &event_turn_id, plan_update_event, api_version, - outgoing.as_ref(), + &outgoing, ) .await; } @@ -1166,7 +1166,7 @@ async fn handle_turn_diff( event_turn_id: &str, turn_diff_event: TurnDiffEvent, api_version: ApiVersion, - outgoing: &OutgoingMessageSender, + outgoing: &ThreadScopedOutgoingMessageSender, ) { if let ApiVersion::V2 = api_version { let notification = TurnDiffUpdatedNotification { @@ -1185,7 +1185,7 @@ async fn handle_turn_plan_update( event_turn_id: &str, plan_update_event: UpdatePlanArgs, api_version: ApiVersion, - outgoing: &OutgoingMessageSender, + outgoing: &ThreadScopedOutgoingMessageSender, ) { // `update_plan` is a todo/checklist tool; it is not related to plan-mode updates if let ApiVersion::V2 = api_version { @@ -1210,7 +1210,7 @@ async fn emit_turn_completed_with_status( event_turn_id: String, status: TurnStatus, error: Option, - outgoing: &OutgoingMessageSender, + outgoing: &ThreadScopedOutgoingMessageSender, ) { let notification = TurnCompletedNotification { thread_id: conversation_id.to_string(), @@ -1232,7 +1232,7 @@ async fn complete_file_change_item( changes: Vec, status: PatchApplyStatus, turn_id: String, - outgoing: &OutgoingMessageSender, + outgoing: &ThreadScopedOutgoingMessageSender, thread_state: &Arc>, ) { let mut state = thread_state.lock().await; @@ -1264,7 +1264,7 @@ async fn complete_command_execution_item( process_id: Option, command_actions: Vec, status: CommandExecutionStatus, - outgoing: &OutgoingMessageSender, + outgoing: &ThreadScopedOutgoingMessageSender, ) { let item = ThreadItem::CommandExecution { id: item_id, @@ -1292,7 +1292,7 @@ async fn maybe_emit_raw_response_item_completed( conversation_id: ThreadId, turn_id: &str, item: codex_protocol::models::ResponseItem, - outgoing: &OutgoingMessageSender, + outgoing: &ThreadScopedOutgoingMessageSender, ) { let ApiVersion::V2 = api_version else { return; @@ -1319,7 +1319,7 @@ async fn find_and_remove_turn_summary( async fn handle_turn_complete( conversation_id: ThreadId, event_turn_id: String, - outgoing: &OutgoingMessageSender, + outgoing: &ThreadScopedOutgoingMessageSender, thread_state: &Arc>, ) { let turn_summary = find_and_remove_turn_summary(conversation_id, thread_state).await; @@ -1335,7 +1335,7 @@ async fn handle_turn_complete( async fn handle_turn_interrupted( conversation_id: ThreadId, event_turn_id: String, - outgoing: &OutgoingMessageSender, + outgoing: &ThreadScopedOutgoingMessageSender, thread_state: &Arc>, ) { find_and_remove_turn_summary(conversation_id, thread_state).await; @@ -1354,7 +1354,7 @@ async fn handle_thread_rollback_failed( _conversation_id: ThreadId, message: String, thread_state: &Arc>, - outgoing: &OutgoingMessageSender, + outgoing: &ThreadScopedOutgoingMessageSender, ) { let pending_rollback = thread_state.lock().await.pending_rollbacks.take(); @@ -1376,7 +1376,7 @@ async fn handle_token_count_event( conversation_id: ThreadId, turn_id: String, token_count_event: TokenCountEvent, - outgoing: &OutgoingMessageSender, + outgoing: &ThreadScopedOutgoingMessageSender, ) { let TokenCountEvent { info, rate_limits } = token_count_event; if let Some(token_usage) = info.map(ThreadTokenUsage::from) { @@ -1633,7 +1633,7 @@ async fn on_file_change_request_approval_response( changes: Vec, receiver: oneshot::Receiver, codex: Arc, - outgoing: Arc, + outgoing: ThreadScopedOutgoingMessageSender, thread_state: Arc>, ) { let response = receiver.await; @@ -1666,7 +1666,7 @@ async fn on_file_change_request_approval_response( changes, status, event_turn_id.clone(), - outgoing.as_ref(), + &outgoing, &thread_state, ) .await; @@ -1693,7 +1693,7 @@ async fn on_command_execution_request_approval_response( command_actions: Vec, receiver: oneshot::Receiver, conversation: Arc, - outgoing: Arc, + outgoing: ThreadScopedOutgoingMessageSender, ) { let response = receiver.await; let (decision, completion_status) = match response { @@ -1748,7 +1748,7 @@ async fn on_command_execution_request_approval_response( None, command_actions.clone(), status, - outgoing.as_ref(), + &outgoing, ) .await; } @@ -1876,6 +1876,7 @@ async fn construct_mcp_tool_call_end_notification( mod tests { use super::*; use crate::CHANNEL_CAPACITY; + use crate::outgoing_message::ConnectionId; use crate::outgoing_message::OutgoingEnvelope; use crate::outgoing_message::OutgoingMessage; use crate::outgoing_message::OutgoingMessageSender; @@ -1914,9 +1915,7 @@ mod tests { .ok_or_else(|| anyhow!("should send one message"))?; match envelope { OutgoingEnvelope::Broadcast { message } => Ok(message), - OutgoingEnvelope::ToConnection { connection_id, .. } => { - bail!("unexpected targeted message for connection {connection_id:?}") - } + OutgoingEnvelope::ToConnection { message, .. } => Ok(message), } } @@ -2011,6 +2010,7 @@ mod tests { let event_turn_id = "complete1".to_string(); let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]); let thread_state = new_thread_state(); handle_turn_complete( @@ -2051,6 +2051,7 @@ mod tests { .await; let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]); handle_turn_interrupted( conversation_id, @@ -2090,6 +2091,7 @@ mod tests { .await; let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]); handle_turn_complete( conversation_id, @@ -2122,7 +2124,8 @@ mod tests { #[tokio::test] async fn test_handle_turn_plan_update_emits_notification_for_v2() -> Result<()> { let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); - let outgoing = OutgoingMessageSender::new(tx); + let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]); let update = UpdatePlanArgs { explanation: Some("need plan".to_string()), plan: vec![ @@ -2172,6 +2175,7 @@ mod tests { let turn_id = "turn-123".to_string(); let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]); let info = TokenUsageInfo { total_token_usage: TokenUsage { @@ -2255,6 +2259,7 @@ mod tests { let turn_id = "turn-456".to_string(); let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]); handle_token_count_event( conversation_id, @@ -2321,6 +2326,7 @@ mod tests { let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]); // Turn 1 on conversation A let a_turn1 = "a_turn1".to_string(); @@ -2542,7 +2548,8 @@ mod tests { #[tokio::test] async fn test_handle_turn_diff_emits_v2_notification() -> Result<()> { let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); - let outgoing = OutgoingMessageSender::new(tx); + let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]); let unified_diff = "--- a\n+++ b\n".to_string(); let conversation_id = ThreadId::new(); @@ -2575,7 +2582,8 @@ mod tests { #[tokio::test] async fn test_handle_turn_diff_is_noop_for_v1() -> Result<()> { let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY); - let outgoing = OutgoingMessageSender::new(tx); + let outgoing = Arc::new(OutgoingMessageSender::new(tx)); + let outgoing = ThreadScopedOutgoingMessageSender::new(outgoing, vec![ConnectionId(1)]); let conversation_id = ThreadId::new(); handle_turn_diff( diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 5e0e32dad76..67a0a71f6ae 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -7,6 +7,7 @@ use crate::outgoing_message::ConnectionId; use crate::outgoing_message::ConnectionRequestId; use crate::outgoing_message::OutgoingMessageSender; use crate::outgoing_message::OutgoingNotification; +use crate::outgoing_message::ThreadScopedOutgoingMessageSender; use chrono::DateTime; use chrono::SecondsFormat; use chrono::Utc; @@ -251,6 +252,7 @@ use uuid::Uuid; use crate::filters::compute_source_filters; use crate::filters::source_kind_matches; +use crate::thread_state::ThreadState; use crate::thread_state::ThreadStateManager; const THREAD_LIST_DEFAULT_LIMIT: usize = 25; @@ -1961,8 +1963,9 @@ impl CodexMessageProcessor { // Auto-attach a thread listener when starting a thread. // Use the same behavior as the v1 API, with opt-in support for raw item events. if let Err(err) = self - .attach_conversation_listener( + .ensure_conversation_listener( thread_id, + request_id.connection_id, experimental_raw_events, ApiVersion::V2, ) @@ -2628,20 +2631,28 @@ impl CodexMessageProcessor { self.thread_manager.subscribe_thread_created() } - /// Best-effort: attach a listener for thread_id if missing. - pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) { - if self.thread_state_manager.has_listener_for_thread(thread_id) { - return; - } + pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) { + self.thread_state_manager + .remove_connection(connection_id) + .await; + } - if let Err(err) = self - .attach_conversation_listener(thread_id, false, ApiVersion::V2) - .await - { - warn!( - "failed to attach listener for thread {thread_id}: {message}", - message = err.message - ); + /// Best-effort: ensure initialized connections are subscribed to this thread. + pub(crate) async fn try_attach_thread_listener( + &mut self, + thread_id: ThreadId, + connection_ids: Vec, + ) { + for connection_id in connection_ids { + if let Err(err) = self + .ensure_conversation_listener(thread_id, connection_id, false, ApiVersion::V2) + .await + { + warn!( + "failed to auto-attach listener for thread {thread_id}: {message}", + message = err.message + ); + } } } @@ -2793,7 +2804,12 @@ impl CodexMessageProcessor { }; // Auto-attach a thread listener when resuming a thread. if let Err(err) = self - .attach_conversation_listener(thread_id, false, ApiVersion::V2) + .ensure_conversation_listener( + thread_id, + request_id.connection_id, + false, + ApiVersion::V2, + ) .await { tracing::warn!( @@ -3019,7 +3035,12 @@ impl CodexMessageProcessor { }; // Auto-attach a conversation listener when forking a thread. if let Err(err) = self - .attach_conversation_listener(thread_id, false, ApiVersion::V2) + .ensure_conversation_listener( + thread_id, + request_id.connection_id, + false, + ApiVersion::V2, + ) .await { tracing::warn!( @@ -5136,7 +5157,12 @@ impl CodexMessageProcessor { })?; if let Err(err) = self - .attach_conversation_listener(thread_id, false, ApiVersion::V2) + .ensure_conversation_listener( + thread_id, + request_id.connection_id, + false, + ApiVersion::V2, + ) .await { tracing::warn!( @@ -5281,18 +5307,38 @@ impl CodexMessageProcessor { conversation_id, experimental_raw_events, } = params; - match self - .attach_conversation_listener(conversation_id, experimental_raw_events, ApiVersion::V1) - .await - { - Ok(subscription_id) => { - let response = AddConversationSubscriptionResponse { subscription_id }; - self.outgoing.send_response(request_id, response).await; - } - Err(err) => { - self.outgoing.send_error(request_id, err).await; + let conversation = match self.thread_manager.get_thread(conversation_id).await { + Ok(conv) => conv, + Err(_) => { + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: format!("thread not found: {conversation_id}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + return; } - } + }; + let subscription_id = Uuid::new_v4(); + let thread_state = self + .thread_state_manager + .set_listener( + subscription_id, + conversation_id, + request_id.connection_id, + experimental_raw_events, + ) + .await; + self.ensure_listener_task_running( + conversation_id, + conversation, + thread_state, + ApiVersion::V1, + ) + .await; + + let response = AddConversationSubscriptionResponse { subscription_id }; + self.outgoing.send_response(request_id, response).await; } async fn remove_thread_listener( @@ -5322,12 +5368,13 @@ impl CodexMessageProcessor { } } - async fn attach_conversation_listener( + async fn ensure_conversation_listener( &mut self, conversation_id: ThreadId, + connection_id: ConnectionId, raw_events_enabled: bool, api_version: ApiVersion, - ) -> Result { + ) -> Result<(), JSONRPCErrorError> { let conversation = match self.thread_manager.get_thread(conversation_id).await { Ok(conv) => conv, Err(_) => { @@ -5338,13 +5385,30 @@ impl CodexMessageProcessor { }); } }; - - let subscription_id = Uuid::new_v4(); - let (cancel_tx, mut cancel_rx) = oneshot::channel(); let thread_state = self .thread_state_manager - .set_listener(subscription_id, conversation_id, cancel_tx) + .ensure_connection_subscribed(conversation_id, connection_id, raw_events_enabled) .await; + self.ensure_listener_task_running(conversation_id, conversation, thread_state, api_version) + .await; + Ok(()) + } + + async fn ensure_listener_task_running( + &self, + conversation_id: ThreadId, + conversation: Arc, + thread_state: Arc>, + api_version: ApiVersion, + ) { + let (cancel_tx, mut cancel_rx) = oneshot::channel(); + { + let mut thread_state = thread_state.lock().await; + if thread_state.listener_matches(&conversation) { + return; + } + thread_state.set_listener(cancel_tx, &conversation); + } let outgoing_for_task = self.outgoing.clone(); let fallback_model_provider = self.config.model_provider_id.clone(); tokio::spawn(async move { @@ -5363,10 +5427,6 @@ impl CodexMessageProcessor { } }; - if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled { - continue; - } - // For now, we send a notification for every event, // JSON-serializing the `Event` as-is, but these should // be migrated to be variants of `ServerNotification` @@ -5391,19 +5451,38 @@ impl CodexMessageProcessor { "conversationId".to_string(), conversation_id.to_string().into(), ); + let (subscribed_connection_ids, raw_events_enabled) = { + let thread_state = thread_state.lock().await; + ( + thread_state.subscribed_connection_ids(), + thread_state.experimental_raw_events, + ) + }; + if let EventMsg::RawResponseItem(_) = &event.msg && !raw_events_enabled { + continue; + } - outgoing_for_task - .send_notification(OutgoingNotification { - method: format!("codex/event/{event_formatted}"), - params: Some(params.into()), - }) - .await; + if !subscribed_connection_ids.is_empty() { + outgoing_for_task + .send_notification_to_connections( + &subscribed_connection_ids, + OutgoingNotification { + method: format!("codex/event/{event_formatted}"), + params: Some(params.into()), + }, + ) + .await; + } + let thread_outgoing = ThreadScopedOutgoingMessageSender::new( + outgoing_for_task.clone(), + subscribed_connection_ids, + ); apply_bespoke_event_handling( event.clone(), conversation_id, conversation.clone(), - outgoing_for_task.clone(), + thread_outgoing, thread_state.clone(), api_version, fallback_model_provider.clone(), @@ -5413,9 +5492,7 @@ impl CodexMessageProcessor { } } }); - Ok(subscription_id) } - async fn git_diff_to_origin(&self, request_id: ConnectionRequestId, cwd: PathBuf) { let diff = git_diff_to_remote(&cwd).await; match diff { @@ -6299,22 +6376,137 @@ mod tests { let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; let listener_a = Uuid::new_v4(); let listener_b = Uuid::new_v4(); - let (cancel_a, cancel_rx_a) = oneshot::channel(); - let (cancel_b, mut cancel_rx_b) = oneshot::channel(); + let connection_a = ConnectionId(1); + let connection_b = ConnectionId(2); + let (cancel_tx, mut cancel_rx) = oneshot::channel(); - manager.set_listener(listener_a, thread_id, cancel_a).await; - manager.set_listener(listener_b, thread_id, cancel_b).await; + manager + .set_listener(listener_a, thread_id, connection_a, false) + .await; + manager + .set_listener(listener_b, thread_id, connection_b, false) + .await; + { + let state = manager.thread_state(thread_id); + state.lock().await.cancel_tx = Some(cancel_tx); + } assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id)); - assert_eq!(cancel_rx_a.await, Ok(())); assert!( - tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx_b) + tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx) .await .is_err() ); - assert_eq!(manager.remove_listener(listener_b).await, Some(thread_id)); - assert_eq!(cancel_rx_b.await, Ok(())); + assert_eq!(cancel_rx.await, Ok(())); + Ok(()) + } + + #[tokio::test] + async fn removing_listener_unsubscribes_its_connection() -> Result<()> { + let mut manager = ThreadStateManager::new(); + let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; + let listener_a = Uuid::new_v4(); + let listener_b = Uuid::new_v4(); + let connection_a = ConnectionId(1); + let connection_b = ConnectionId(2); + + manager + .set_listener(listener_a, thread_id, connection_a, false) + .await; + manager + .set_listener(listener_b, thread_id, connection_b, false) + .await; + + assert_eq!(manager.remove_listener(listener_a).await, Some(thread_id)); + let state = manager.thread_state(thread_id); + let subscribed_connection_ids = state.lock().await.subscribed_connection_ids(); + assert_eq!(subscribed_connection_ids, vec![connection_b]); + Ok(()) + } + + #[tokio::test] + async fn set_listener_uses_last_write_for_raw_events() -> Result<()> { + let mut manager = ThreadStateManager::new(); + let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; + let listener_a = Uuid::new_v4(); + let listener_b = Uuid::new_v4(); + let connection_a = ConnectionId(1); + let connection_b = ConnectionId(2); + + manager + .set_listener(listener_a, thread_id, connection_a, true) + .await; + { + let state = manager.thread_state(thread_id); + assert!(state.lock().await.experimental_raw_events); + } + manager + .set_listener(listener_b, thread_id, connection_b, false) + .await; + let state = manager.thread_state(thread_id); + assert!(!state.lock().await.experimental_raw_events); + Ok(()) + } + + #[tokio::test] + async fn removing_connection_clears_subscription_and_listener_when_last_subscriber() + -> Result<()> { + let mut manager = ThreadStateManager::new(); + let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; + let listener = Uuid::new_v4(); + let connection = ConnectionId(1); + let (cancel_tx, cancel_rx) = oneshot::channel(); + + manager + .set_listener(listener, thread_id, connection, false) + .await; + { + let state = manager.thread_state(thread_id); + state.lock().await.cancel_tx = Some(cancel_tx); + } + + manager.remove_connection(connection).await; + assert_eq!(cancel_rx.await, Ok(())); + assert_eq!(manager.remove_listener(listener).await, None); + + let state = manager.thread_state(thread_id); + assert!(state.lock().await.subscribed_connection_ids().is_empty()); + Ok(()) + } + + #[tokio::test] + async fn removing_auto_attached_connection_preserves_listener_for_other_connections() + -> Result<()> { + let mut manager = ThreadStateManager::new(); + let thread_id = ThreadId::from_string("ad7f0408-99b8-4f6e-a46f-bd0eec433370")?; + let connection_a = ConnectionId(1); + let connection_b = ConnectionId(2); + let (cancel_tx, mut cancel_rx) = oneshot::channel(); + + manager + .ensure_connection_subscribed(thread_id, connection_a, false) + .await; + manager + .ensure_connection_subscribed(thread_id, connection_b, false) + .await; + { + let state = manager.thread_state(thread_id); + state.lock().await.cancel_tx = Some(cancel_tx); + } + + manager.remove_connection(connection_a).await; + assert!( + tokio::time::timeout(Duration::from_millis(20), &mut cancel_rx) + .await + .is_err() + ); + + let state = manager.thread_state(thread_id); + assert_eq!( + state.lock().await.subscribed_connection_ids(), + vec![connection_b] + ); Ok(()) } } diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 06595d5b39c..31bc831ca18 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -27,7 +27,6 @@ use crate::transport::CHANNEL_CAPACITY; use crate::transport::ConnectionState; use crate::transport::OutboundConnectionState; use crate::transport::TransportEvent; -use crate::transport::has_initialized_connections; use crate::transport::route_outgoing_envelope; use crate::transport::start_stdio_connection; use crate::transport::start_websocket_acceptor; @@ -490,6 +489,7 @@ pub async fn run_main_with_transport( { break; } + processor.connection_closed(connection_id).await; connections.remove(&connection_id); if shutdown_when_no_connections && connections.is_empty() { break; @@ -544,8 +544,19 @@ pub async fn run_main_with_transport( created = thread_created_rx.recv(), if listen_for_threads => { match created { Ok(thread_id) => { - if has_initialized_connections(&connections) { - processor.try_attach_thread_listener(thread_id).await; + let initialized_connection_ids: Vec = connections + .iter() + .filter_map(|(connection_id, connection_state)| { + connection_state.session.initialized.then_some(*connection_id) + }) + .collect(); + if !initialized_connection_ids.is_empty() { + processor + .try_attach_thread_listener( + thread_id, + initialized_connection_ids, + ) + .await; } } Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index e7d4a0b6198..f4b949ed28a 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -396,9 +396,19 @@ impl MessageProcessor { } } - pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) { + pub(crate) async fn try_attach_thread_listener( + &mut self, + thread_id: ThreadId, + connection_ids: Vec, + ) { + self.codex_message_processor + .try_attach_thread_listener(thread_id, connection_ids) + .await; + } + + pub(crate) async fn connection_closed(&mut self, connection_id: ConnectionId) { self.codex_message_processor - .try_attach_thread_listener(thread_id) + .connection_closed(connection_id) .await; } diff --git a/codex-rs/app-server/src/outgoing_message.rs b/codex-rs/app-server/src/outgoing_message.rs index 9740393efda..45f33b49d93 100644 --- a/codex-rs/app-server/src/outgoing_message.rs +++ b/codex-rs/app-server/src/outgoing_message.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -48,6 +49,62 @@ pub(crate) struct OutgoingMessageSender { request_id_to_callback: Mutex>>, } +#[derive(Clone)] +pub(crate) struct ThreadScopedOutgoingMessageSender { + outgoing: Arc, + connection_ids: Arc>, +} + +impl ThreadScopedOutgoingMessageSender { + pub(crate) fn new( + outgoing: Arc, + connection_ids: Vec, + ) -> Self { + Self { + outgoing, + connection_ids: Arc::new(connection_ids), + } + } + + pub(crate) async fn send_request( + &self, + payload: ServerRequestPayload, + ) -> oneshot::Receiver { + if self.connection_ids.is_empty() { + let (_tx, rx) = oneshot::channel(); + return rx; + } + self.outgoing + .send_request_to_connections(self.connection_ids.as_slice(), payload) + .await + } + + pub(crate) async fn send_server_notification(&self, notification: ServerNotification) { + if self.connection_ids.is_empty() { + return; + } + self.outgoing + .send_server_notification_to_connections(self.connection_ids.as_slice(), notification) + .await; + } + + pub(crate) async fn send_response( + &self, + request_id: ConnectionRequestId, + response: T, + ) { + self.outgoing.send_response(request_id, response).await; + } + + pub(crate) async fn send_error( + &self, + request_id: ConnectionRequestId, + error: JSONRPCErrorError, + ) { + self.outgoing.send_error(request_id, error).await; + } +} + impl OutgoingMessageSender { pub(crate) fn new(sender: mpsc::Sender) -> Self { Self { @@ -57,17 +114,28 @@ impl OutgoingMessageSender { } } - pub(crate) async fn send_request( + pub(crate) async fn send_request_to_connections( &self, + connection_ids: &[ConnectionId], request: ServerRequestPayload, ) -> oneshot::Receiver { - let (_id, rx) = self.send_request_with_id(request).await; + let (_id, rx) = self + .send_request_with_id_to_connections(connection_ids, request) + .await; rx } pub(crate) async fn send_request_with_id( &self, request: ServerRequestPayload, + ) -> (RequestId, oneshot::Receiver) { + self.send_request_with_id_to_connections(&[], request).await + } + + async fn send_request_with_id_to_connections( + &self, + connection_ids: &[ConnectionId], + request: ServerRequestPayload, ) -> (RequestId, oneshot::Receiver) { let id = RequestId::Integer(self.next_server_request_id.fetch_add(1, Ordering::Relaxed)); let outgoing_message_id = id.clone(); @@ -79,13 +147,34 @@ impl OutgoingMessageSender { let outgoing_message = OutgoingMessage::Request(request.request_with_id(outgoing_message_id.clone())); - if let Err(err) = self - .sender - .send(OutgoingEnvelope::Broadcast { - message: outgoing_message, - }) - .await - { + let send_result = if connection_ids.is_empty() { + self.sender + .send(OutgoingEnvelope::Broadcast { + message: outgoing_message, + }) + .await + } else { + let mut send_error = None; + for connection_id in connection_ids { + if let Err(err) = self + .sender + .send(OutgoingEnvelope::ToConnection { + connection_id: *connection_id, + message: outgoing_message.clone(), + }) + .await + { + send_error = Some(err); + break; + } + } + match send_error { + Some(err) => Err(err), + None => Ok(()), + } + }; + + if let Err(err) = send_result { warn!("failed to send request {outgoing_message_id:?} to client: {err:?}"); let mut request_id_to_callback = self.request_id_to_callback.lock().await; request_id_to_callback.remove(&outgoing_message_id); @@ -172,29 +261,71 @@ impl OutgoingMessageSender { } pub(crate) async fn send_server_notification(&self, notification: ServerNotification) { - if let Err(err) = self - .sender - .send(OutgoingEnvelope::Broadcast { - message: OutgoingMessage::AppServerNotification(notification), - }) - .await - { - warn!("failed to send server notification to client: {err:?}"); + self.send_server_notification_to_connections(&[], notification) + .await; + } + + pub(crate) async fn send_server_notification_to_connections( + &self, + connection_ids: &[ConnectionId], + notification: ServerNotification, + ) { + let outgoing_message = OutgoingMessage::AppServerNotification(notification); + if connection_ids.is_empty() { + if let Err(err) = self + .sender + .send(OutgoingEnvelope::Broadcast { + message: outgoing_message, + }) + .await + { + warn!("failed to send server notification to client: {err:?}"); + } + return; + } + for connection_id in connection_ids { + if let Err(err) = self + .sender + .send(OutgoingEnvelope::ToConnection { + connection_id: *connection_id, + message: outgoing_message.clone(), + }) + .await + { + warn!("failed to send server notification to client: {err:?}"); + } } } - /// All notifications should be migrated to [`ServerNotification`] and - /// [`OutgoingMessage::Notification`] should be removed. - pub(crate) async fn send_notification(&self, notification: OutgoingNotification) { + pub(crate) async fn send_notification_to_connections( + &self, + connection_ids: &[ConnectionId], + notification: OutgoingNotification, + ) { let outgoing_message = OutgoingMessage::Notification(notification); - if let Err(err) = self - .sender - .send(OutgoingEnvelope::Broadcast { - message: outgoing_message, - }) - .await - { - warn!("failed to send notification to client: {err:?}"); + if connection_ids.is_empty() { + if let Err(err) = self + .sender + .send(OutgoingEnvelope::Broadcast { + message: outgoing_message, + }) + .await + { + warn!("failed to send notification to client: {err:?}"); + } + return; + } + for connection_id in connection_ids { + if let Err(err) = self + .sender + .send(OutgoingEnvelope::ToConnection { + connection_id: *connection_id, + message: outgoing_message.clone(), + }) + .await + { + warn!("failed to send notification to client: {err:?}"); + } } } diff --git a/codex-rs/app-server/src/thread_state.rs b/codex-rs/app-server/src/thread_state.rs index eb263c5e4f2..2ccb7eaae75 100644 --- a/codex-rs/app-server/src/thread_state.rs +++ b/codex-rs/app-server/src/thread_state.rs @@ -1,9 +1,12 @@ +use crate::outgoing_message::ConnectionId; use crate::outgoing_message::ConnectionRequestId; use codex_app_server_protocol::TurnError; +use codex_core::CodexThread; use codex_protocol::ThreadId; use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; +use std::sync::Weak; use tokio::sync::Mutex; use tokio::sync::oneshot; use uuid::Uuid; @@ -25,33 +28,66 @@ pub(crate) struct ThreadState { pub(crate) pending_interrupts: PendingInterruptQueue, pub(crate) pending_rollbacks: Option, pub(crate) turn_summary: TurnSummary, - pub(crate) listener_cancel_txs: HashMap>, + pub(crate) cancel_tx: Option>, + pub(crate) experimental_raw_events: bool, + listener_thread: Option>, + subscribed_connections: HashSet, } impl ThreadState { - fn set_listener(&mut self, subscription_id: Uuid, cancel_tx: oneshot::Sender<()>) { - if let Some(previous) = self.listener_cancel_txs.insert(subscription_id, cancel_tx) { + pub(crate) fn listener_matches(&self, conversation: &Arc) -> bool { + self.listener_thread + .as_ref() + .and_then(Weak::upgrade) + .is_some_and(|existing| Arc::ptr_eq(&existing, conversation)) + } + + pub(crate) fn set_listener( + &mut self, + cancel_tx: oneshot::Sender<()>, + conversation: &Arc, + ) { + if let Some(previous) = self.cancel_tx.replace(cancel_tx) { let _ = previous.send(()); } + self.listener_thread = Some(Arc::downgrade(conversation)); } - fn clear_listener(&mut self, subscription_id: Uuid) { - if let Some(cancel_tx) = self.listener_cancel_txs.remove(&subscription_id) { + pub(crate) fn clear_listener(&mut self) { + if let Some(cancel_tx) = self.cancel_tx.take() { let _ = cancel_tx.send(()); } + self.listener_thread = None; } - fn clear_listeners(&mut self) { - for (_, cancel_tx) in self.listener_cancel_txs.drain() { - let _ = cancel_tx.send(()); - } + pub(crate) fn add_connection(&mut self, connection_id: ConnectionId) { + self.subscribed_connections.insert(connection_id); + } + + pub(crate) fn remove_connection(&mut self, connection_id: ConnectionId) { + self.subscribed_connections.remove(&connection_id); + } + + pub(crate) fn subscribed_connection_ids(&self) -> Vec { + self.subscribed_connections.iter().copied().collect() + } + + pub(crate) fn set_experimental_raw_events(&mut self, enabled: bool) { + self.experimental_raw_events = enabled; } } +#[derive(Clone, Copy)] +struct SubscriptionState { + thread_id: ThreadId, + connection_id: ConnectionId, +} + #[derive(Default)] pub(crate) struct ThreadStateManager { thread_states: HashMap>>, - thread_id_by_subscription: HashMap, + subscription_state_by_id: HashMap, + thread_ids_by_connection: HashMap>, } impl ThreadStateManager { @@ -59,12 +95,6 @@ impl ThreadStateManager { Self::default() } - pub(crate) fn has_listener_for_thread(&self, thread_id: ThreadId) -> bool { - self.thread_id_by_subscription - .values() - .any(|existing| *existing == thread_id) - } - pub(crate) fn thread_state(&mut self, thread_id: ThreadId) -> Arc> { self.thread_states .entry(thread_id) @@ -73,34 +103,119 @@ impl ThreadStateManager { } pub(crate) async fn remove_listener(&mut self, subscription_id: Uuid) -> Option { - let thread_id = self.thread_id_by_subscription.remove(&subscription_id)?; + let subscription_state = self.subscription_state_by_id.remove(&subscription_id)?; + let thread_id = subscription_state.thread_id; + + let connection_still_subscribed_to_thread = + self.subscription_state_by_id.values().any(|state| { + state.thread_id == thread_id + && state.connection_id == subscription_state.connection_id + }); + if !connection_still_subscribed_to_thread { + let mut remove_connection_entry = false; + if let Some(thread_ids) = self + .thread_ids_by_connection + .get_mut(&subscription_state.connection_id) + { + thread_ids.remove(&thread_id); + remove_connection_entry = thread_ids.is_empty(); + } + if remove_connection_entry { + self.thread_ids_by_connection + .remove(&subscription_state.connection_id); + } + if let Some(thread_state) = self.thread_states.get(&thread_id) { + thread_state + .lock() + .await + .remove_connection(subscription_state.connection_id); + } + } + if let Some(thread_state) = self.thread_states.get(&thread_id) { - thread_state.lock().await.clear_listener(subscription_id); + let mut thread_state = thread_state.lock().await; + if thread_state.subscribed_connection_ids().is_empty() { + thread_state.clear_listener(); + } } Some(thread_id) } pub(crate) async fn remove_thread_state(&mut self, thread_id: ThreadId) { if let Some(thread_state) = self.thread_states.remove(&thread_id) { - thread_state.lock().await.clear_listeners(); + thread_state.lock().await.clear_listener(); } - self.thread_id_by_subscription - .retain(|_, existing_thread_id| *existing_thread_id != thread_id); + self.subscription_state_by_id + .retain(|_, state| state.thread_id != thread_id); + self.thread_ids_by_connection.retain(|_, thread_ids| { + thread_ids.remove(&thread_id); + !thread_ids.is_empty() + }); } pub(crate) async fn set_listener( &mut self, subscription_id: Uuid, thread_id: ThreadId, - cancel_tx: oneshot::Sender<()>, + connection_id: ConnectionId, + experimental_raw_events: bool, ) -> Arc> { - self.thread_id_by_subscription - .insert(subscription_id, thread_id); + self.subscription_state_by_id.insert( + subscription_id, + SubscriptionState { + thread_id, + connection_id, + }, + ); + self.thread_ids_by_connection + .entry(connection_id) + .or_default() + .insert(thread_id); let thread_state = self.thread_state(thread_id); + { + let mut thread_state_guard = thread_state.lock().await; + thread_state_guard.add_connection(connection_id); + thread_state_guard.set_experimental_raw_events(experimental_raw_events); + } thread_state - .lock() - .await - .set_listener(subscription_id, cancel_tx); + } + + pub(crate) async fn ensure_connection_subscribed( + &mut self, + thread_id: ThreadId, + connection_id: ConnectionId, + experimental_raw_events: bool, + ) -> Arc> { + self.thread_ids_by_connection + .entry(connection_id) + .or_default() + .insert(thread_id); + let thread_state = self.thread_state(thread_id); + { + let mut thread_state_guard = thread_state.lock().await; + thread_state_guard.add_connection(connection_id); + if experimental_raw_events { + thread_state_guard.set_experimental_raw_events(true); + } + } thread_state } + + pub(crate) async fn remove_connection(&mut self, connection_id: ConnectionId) { + let Some(thread_ids) = self.thread_ids_by_connection.remove(&connection_id) else { + return; + }; + self.subscription_state_by_id + .retain(|_, state| state.connection_id != connection_id); + + for thread_id in thread_ids { + if let Some(thread_state) = self.thread_states.get(&thread_id) { + let mut thread_state = thread_state.lock().await; + thread_state.remove_connection(connection_id); + if thread_state.subscribed_connection_ids().is_empty() { + thread_state.clear_listener(); + } + } + } + } } diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index cbfd263a555..558d786137e 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -478,6 +478,9 @@ pub(crate) async fn route_outgoing_envelope( ); return disconnected; }; + if should_skip_notification_for_connection(connection_state, &message) { + return disconnected; + } if connection_state.writer.send(message).await.is_err() { connections.remove(&connection_id); disconnected.push(connection_id); @@ -511,14 +514,6 @@ pub(crate) async fn route_outgoing_envelope( disconnected } -pub(crate) fn has_initialized_connections( - connections: &HashMap, -) -> bool { - connections - .values() - .any(|connection| connection.session.initialized) -} - #[cfg(test)] mod tests { use super::*; @@ -746,4 +741,40 @@ mod tests { let queued_json = serde_json::to_value(queued_outgoing).expect("serialize queued message"); assert_eq!(queued_json, json!({ "method": "queued" })); } + + #[tokio::test] + async fn to_connection_notification_respects_opt_out_filters() { + let connection_id = ConnectionId(7); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + let initialized = Arc::new(AtomicBool::new(true)); + let opted_out_notification_methods = Arc::new(RwLock::new(HashSet::from([ + "codex/event/task_started".to_string(), + ]))); + + let mut connections = HashMap::new(); + connections.insert( + connection_id, + OutboundConnectionState::new(writer_tx, initialized, opted_out_notification_methods), + ); + + let disconnected = route_outgoing_envelope( + &mut connections, + OutgoingEnvelope::ToConnection { + connection_id, + message: OutgoingMessage::Notification( + crate::outgoing_message::OutgoingNotification { + method: "codex/event/task_started".to_string(), + params: None, + }, + ), + }, + ) + .await; + + assert_eq!(disconnected, Vec::::new()); + assert!( + writer_rx.try_recv().is_err(), + "opted-out notification should be dropped" + ); + } }