diff --git a/codex-rs/app-server-protocol/src/protocol/common.rs b/codex-rs/app-server-protocol/src/protocol/common.rs index dcab5eaed1..6e3fe5eb25 100644 --- a/codex-rs/app-server-protocol/src/protocol/common.rs +++ b/codex-rs/app-server-protocol/src/protocol/common.rs @@ -413,6 +413,18 @@ impl TryFrom for ServerNotification { #[strum(serialize_all = "camelCase")] pub enum ClientNotification { Initialized, + /// LSP-style cancellation of an in-flight JSON-RPC request. + /// Shape: { "method": "$/cancelRequest", "params": { "id": } } + #[serde(rename = "$/cancelRequest")] + #[ts(rename = "$/cancelRequest")] + #[strum(serialize = "$/cancelRequest")] + CancelRequest(CancelRequestParams), +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)] +#[serde(rename_all = "camelCase")] +pub struct CancelRequestParams { + pub id: RequestId, } #[cfg(test)] @@ -499,6 +511,22 @@ mod tests { Ok(()) } + #[test] + fn serialize_cancel_request_notification() -> Result<()> { + let notification = ClientNotification::CancelRequest(CancelRequestParams { + id: RequestId::Integer(7), + }); + + assert_eq!( + json!({ + "method": "$/cancelRequest", + "params": { "id": 7 } + }), + serde_json::to_value(¬ification)?, + ); + Ok(()) + } + #[test] fn serialize_server_request() -> Result<()> { let conversation_id = ConversationId::from_string("67e55044-10b1-426f-9247-bb680e5fe0c8")?; diff --git a/codex-rs/app-server/src/cancellation_registry.rs b/codex-rs/app-server/src/cancellation_registry.rs new file mode 100644 index 0000000000..01957ce463 --- /dev/null +++ b/codex-rs/app-server/src/cancellation_registry.rs @@ -0,0 +1,56 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; +use std::sync::PoisonError; + +use codex_app_server_protocol::RequestId; + +trait Cancellable: Send { + fn cancel(&self); +} + +impl Cancellable for F +where + F: Fn() + Send, +{ + fn cancel(&self) { + (self)(); + } +} + +#[derive(Clone, Default)] +pub(crate) struct CancellationRegistry { + inner: Arc>>>, +} + +impl CancellationRegistry { + pub(crate) fn insert(&self, id: RequestId, f: F) + where + F: Fn() + Send + 'static, + { + self.inner + .lock() + .unwrap_or_else(PoisonError::into_inner) + .insert(id, Box::new(f)); + } + + pub(crate) fn cancel(&self, id: &RequestId) -> bool { + // Remove the callback while holding the lock, but invoke it only after + // releasing the lock to avoid deadlocks or long critical sections. + let callback = { + let mut guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner); + guard.remove(id) + }; + if let Some(c) = callback { + c.cancel(); + true + } else { + false + } + } + + pub(crate) fn remove(&self, id: &RequestId) { + let mut guard = self.inner.lock().unwrap_or_else(PoisonError::into_inner); + guard.remove(id); + } +} diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 99310e198f..589a56559a 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -118,11 +118,14 @@ use tracing::info; use tracing::warn; use uuid::Uuid; +use crate::cancellation_registry::CancellationRegistry; + // Duration before a ChatGPT login attempt is abandoned. const LOGIN_CHATGPT_TIMEOUT: Duration = Duration::from_secs(10 * 60); struct ActiveLogin { shutdown_handle: ShutdownHandle, login_id: Uuid, + request_id: RequestId, } impl ActiveLogin { @@ -144,6 +147,7 @@ pub(crate) struct CodexMessageProcessor { pending_interrupts: Arc>>>, pending_fuzzy_searches: Arc>>>, feedback: CodexFeedback, + cancellation_registry: CancellationRegistry, } impl CodexMessageProcessor { @@ -166,6 +170,7 @@ impl CodexMessageProcessor { pending_interrupts: Arc::new(Mutex::new(HashMap::new())), pending_fuzzy_searches: Arc::new(Mutex::new(HashMap::new())), feedback, + cancellation_registry: CancellationRegistry::default(), } } @@ -391,13 +396,22 @@ impl CodexMessageProcessor { let mut guard = self.active_login.lock().await; if let Some(existing) = guard.take() { existing.drop(); + self.cancellation_registry.remove(&existing.request_id); } *guard = Some(ActiveLogin { shutdown_handle: shutdown_handle.clone(), login_id, + request_id: request_id.clone(), }); } + // Register cancellation for this request id so $/cancelRequest works. + let shutdown_for_cancel = shutdown_handle.clone(); + self.cancellation_registry + .insert(request_id.clone(), move || { + shutdown_for_cancel.shutdown(); + }); + let response = LoginChatGptResponse { login_id, auth_url: server.auth_url.clone(), @@ -407,6 +421,8 @@ impl CodexMessageProcessor { let outgoing_clone = self.outgoing.clone(); let active_login = self.active_login.clone(); let auth_manager = self.auth_manager.clone(); + let cancellation_registry = self.cancellation_registry.clone(); + let request_id_for_task = request_id.clone(); tokio::spawn(async move { let (success, error_msg) = match tokio::time::timeout( LOGIN_CHATGPT_TIMEOUT, @@ -451,6 +467,8 @@ impl CodexMessageProcessor { if guard.as_ref().map(|l| l.login_id) == Some(login_id) { *guard = None; } + + cancellation_registry.remove(&request_id_for_task); }); LoginChatGptReply::Response(response) @@ -470,6 +488,24 @@ impl CodexMessageProcessor { } } + /// Handle a generic JSON-RPC `$ /cancelRequest` for a previously started operation. + /// + /// Note: individual request handlers that wish to be cancellable must + /// register a cancellation action in the `CancellationRegistry` using the + /// original JSON-RPC `request_id` when they start work. This method looks up + /// that action by id and triggers it if found. It is fire-and-forget; no + /// JSON-RPC response is sent. + pub async fn cancel_request(&self, id: RequestId) { + let found = self.cancellation_registry.cancel(&id); + if !found { + tracing::debug!( + "$/cancelRequest for unknown or already-finished id: {:?}", + id + ); + } + } + + // Legacy endpoint for cancelling a LoginChatGpt request. Please use $/cancelRequest instead. async fn cancel_login_chatgpt(&mut self, request_id: RequestId, login_id: Uuid) { let mut guard = self.active_login.lock().await; if guard.as_ref().map(|l| l.login_id) == Some(login_id) { diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 6ef986919f..d340cb9b2f 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -28,6 +28,7 @@ use tracing_subscriber::filter::Targets; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; +mod cancellation_registry; mod codex_message_processor; mod error_code; mod fuzzy_file_search; diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 7693cc2ff9..7b00614ce9 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -3,7 +3,9 @@ use std::path::PathBuf; use crate::codex_message_processor::CodexMessageProcessor; use crate::error_code::INVALID_REQUEST_ERROR_CODE; use crate::outgoing_message::OutgoingMessageSender; +use codex_app_server_protocol::CancelRequestParams; use codex_app_server_protocol::ClientInfo; +use codex_app_server_protocol::ClientNotification; use codex_app_server_protocol::ClientRequest; use codex_app_server_protocol::InitializeResponse; @@ -125,9 +127,20 @@ impl MessageProcessor { } pub(crate) async fn process_notification(&self, notification: JSONRPCNotification) { - // Currently, we do not expect to receive any notifications from the - // client, so we just log them. tracing::info!("<- notification: {:?}", notification); + + if let Ok(value) = serde_json::to_value(¬ification) + && let Ok(typed) = serde_json::from_value::(value) + { + match typed { + ClientNotification::CancelRequest(CancelRequestParams { id }) => { + self.codex_message_processor.cancel_request(id).await; + } + ClientNotification::Initialized => { + // Already handled during handshake; ignore. + } + } + } } /// Handle a standalone JSON-RPC response originating from the peer. diff --git a/codex-rs/app-server/tests/suite/login.rs b/codex-rs/app-server/tests/suite/login.rs index c5470c3ec4..9a04c66071 100644 --- a/codex-rs/app-server/tests/suite/login.rs +++ b/codex-rs/app-server/tests/suite/login.rs @@ -3,6 +3,8 @@ use app_test_support::McpProcess; use app_test_support::to_response; use codex_app_server_protocol::CancelLoginChatGptParams; use codex_app_server_protocol::CancelLoginChatGptResponse; +use codex_app_server_protocol::CancelRequestParams; +use codex_app_server_protocol::ClientNotification; use codex_app_server_protocol::GetAuthStatusParams; use codex_app_server_protocol::GetAuthStatusResponse; use codex_app_server_protocol::JSONRPCError; @@ -10,6 +12,7 @@ use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::LoginChatGptResponse; use codex_app_server_protocol::LogoutChatGptResponse; use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ServerNotification; use codex_core::auth::AuthCredentialsStoreMode; use codex_login::login_with_api_key; use serial_test::serial; @@ -204,3 +207,44 @@ async fn login_chatgpt_includes_forced_workspace_query_param() -> Result<()> { ); Ok(()) } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial(login_port)] +async fn login_chatgpt_cancelled_via_cancel_request() -> Result<()> { + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path())?; + + let mut mcp = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; + + let login_request_id = mcp.send_login_chat_gpt_request().await?; + + let login_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(login_request_id)), + ) + .await??; + let login: LoginChatGptResponse = to_response(login_resp)?; + + // Cancel the in-flight request using LSP-style $/cancelRequest. + mcp.send_notification(ClientNotification::CancelRequest(CancelRequestParams { + id: RequestId::Integer(login_request_id), + })) + .await?; + + // Expect a loginChatGptComplete notification indicating cancellation (success = false). + let note = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("loginChatGptComplete"), + ) + .await??; + + let parsed: ServerNotification = note.try_into()?; + let ServerNotification::LoginChatGptComplete(payload) = parsed else { + anyhow::bail!("expected loginChatGptComplete notification"); + }; + assert_eq!(payload.login_id, login.login_id); + assert!(!payload.success); + assert!(payload.error.is_some()); + Ok(()) +}