diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 9481950faf3..3251e81b3e4 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1348,6 +1348,7 @@ dependencies = [ "axum", "base64 0.22.1", "chrono", + "clap", "codex-app-server-protocol", "codex-arg0", "codex-backend-client", @@ -1361,10 +1362,13 @@ dependencies = [ "codex-protocol", "codex-rmcp-client", "codex-utils-absolute-path", + "codex-utils-cargo-bin", "codex-utils-cli", "codex-utils-json-to-toml", "core_test_support", + "futures", "os_info", + "owo-colors", "pretty_assertions", "rmcp", "serde", @@ -1374,6 +1378,7 @@ dependencies = [ "tempfile", "time", "tokio", + "tokio-tungstenite", "toml 0.9.12+spec-1.1.0", "tracing", "tracing-subscriber", diff --git a/codex-rs/app-server/Cargo.toml b/codex-rs/app-server/Cargo.toml index 7cd2091826d..84ace4e38d0 100644 --- a/codex-rs/app-server/Cargo.toml +++ b/codex-rs/app-server/Cargo.toml @@ -30,8 +30,12 @@ codex-protocol = { workspace = true } codex-app-server-protocol = { workspace = true } codex-feedback = { workspace = true } codex-rmcp-client = { workspace = true } +codex-utils-absolute-path = { workspace = true } codex-utils-json-to-toml = { workspace = true } chrono = { workspace = true } +clap = { workspace = true, features = ["derive"] } +futures = { workspace = true } +owo-colors = { workspace = true, features = ["supports-colors"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tempfile = { workspace = true } @@ -44,6 +48,7 @@ tokio = { workspace = true, features = [ "rt-multi-thread", "signal", ] } +tokio-tungstenite = { workspace = true } tracing = { workspace = true, features = ["log"] } tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } uuid = { workspace = true, features = ["serde", "v7"] } @@ -57,8 +62,8 @@ axum = { workspace = true, default-features = false, features = [ ] } base64 = { workspace = true } codex-execpolicy = { workspace = true } -codex-utils-absolute-path = { workspace = true } core_test_support = { workspace = true } +codex-utils-cargo-bin = { workspace = true } os_info = { workspace = true } pretty_assertions = { workspace = true } rmcp = { workspace = true, default-features = false, features = [ @@ -66,5 +71,6 @@ rmcp = { workspace = true, default-features = false, features = [ "transport-streamable-http-server", ] } serial_test = { workspace = true } +tokio-tungstenite = { workspace = true } wiremock = { workspace = true } shlex = { workspace = true } diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 7608262077c..b6c7f9e4f4e 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -19,7 +19,20 @@ ## Protocol -Similar to [MCP](https://modelcontextprotocol.io/), `codex app-server` supports bidirectional communication, streaming JSONL over stdio. The protocol is JSON-RPC 2.0, though the `"jsonrpc":"2.0"` header is omitted. +Similar to [MCP](https://modelcontextprotocol.io/), `codex app-server` supports bidirectional communication using JSON-RPC 2.0 messages (with the `"jsonrpc":"2.0"` header omitted on the wire). + +Supported transports: + +- stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL) +- websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**) + +Websocket transport is currently experimental and unsupported. Do not rely on it for production workloads. + +Backpressure behavior: + +- The server uses bounded queues between transport ingress, request processing, and outbound writes. +- When request ingress is saturated, new requests are rejected with a JSON-RPC error code `-32001` and message `"Server overloaded; retry later."`. +- Clients should treat this as retryable and use exponential backoff with jitter. ## Message Schema @@ -42,7 +55,7 @@ Use the thread APIs to create, list, or archive conversations. Drive a conversat ## Lifecycle Overview -- Initialize once: Immediately after launching the codex app-server process, send an `initialize` request with your client metadata, then emit an `initialized` notification. Any other request before this handshake gets rejected. +- Initialize once per connection: Immediately after opening a transport connection, send an `initialize` request with your client metadata, then emit an `initialized` notification. Any other request on that connection before this handshake gets rejected. - Start (or resume) a thread: Call `thread/start` to open a fresh conversation. The response returns the thread object and you’ll also get a `thread/started` notification. If you’re continuing an existing conversation, call `thread/resume` with its ID instead. If you want to branch from an existing conversation, call `thread/fork` to create a new thread id with copied history. - Begin a turn: To send user input, call `turn/start` with the target `threadId` and the user's input. Optional fields let you override model, cwd, sandbox policy, etc. This immediately returns the new turn object and triggers a `turn/started` notification. - Stream events: After `turn/start`, keep reading JSON-RPC notifications on stdout. You’ll see `item/started`, `item/completed`, deltas like `item/agentMessage/delta`, tool progress, etc. These represent streaming model output plus any side effects (commands, tool calls, reasoning notes). @@ -50,7 +63,7 @@ Use the thread APIs to create, list, or archive conversations. Drive a conversat ## Initialization -Clients must send a single `initialize` request before invoking any other method, then acknowledge with an `initialized` notification. The server returns the user agent string it will present to upstream services; subsequent requests issued before initialization receive a `"Not initialized"` error, and repeated `initialize` calls receive an `"Already initialized"` error. +Clients must send a single `initialize` request per transport connection before invoking any other method on that connection, then acknowledge with an `initialized` notification. The server returns the user agent string it will present to upstream services; subsequent requests issued before initialization receive a `"Not initialized"` error, and repeated `initialize` calls on the same connection receive an `"Already initialized"` error. `initialize.params.capabilities` also supports per-connection notification opt-out via `optOutNotificationMethods`, which is a list of exact method names to suppress for that connection. Matching is exact (no wildcards/prefixes). Unknown method names are accepted and ignored. diff --git a/codex-rs/app-server/src/bespoke_event_handling.rs b/codex-rs/app-server/src/bespoke_event_handling.rs index 7d9b15db7b2..cfa286b45c8 100644 --- a/codex-rs/app-server/src/bespoke_event_handling.rs +++ b/codex-rs/app-server/src/bespoke_event_handling.rs @@ -1115,7 +1115,7 @@ pub(crate) async fn apply_bespoke_event_handling( ), data: None, }; - outgoing.send_error(request_id, error).await; + outgoing.send_error(request_id.clone(), error).await; return; } } @@ -1129,7 +1129,7 @@ pub(crate) async fn apply_bespoke_event_handling( ), data: None, }; - outgoing.send_error(request_id, error).await; + outgoing.send_error(request_id.clone(), error).await; return; } }; @@ -1894,6 +1894,7 @@ async fn construct_mcp_tool_call_end_notification( mod tests { use super::*; use crate::CHANNEL_CAPACITY; + use crate::outgoing_message::OutgoingEnvelope; use crate::outgoing_message::OutgoingMessage; use crate::outgoing_message::OutgoingMessageSender; use anyhow::Result; @@ -1923,6 +1924,21 @@ mod tests { Arc::new(Mutex::new(HashMap::new())) } + async fn recv_broadcast_message( + rx: &mut mpsc::Receiver, + ) -> Result { + let envelope = rx + .recv() + .await + .ok_or_else(|| anyhow!("should send one message"))?; + match envelope { + OutgoingEnvelope::Broadcast { message } => Ok(message), + OutgoingEnvelope::ToConnection { connection_id, .. } => { + bail!("unexpected targeted message for connection {connection_id:?}") + } + } + } + #[test] fn file_change_accept_for_session_maps_to_approved_for_session() { let (decision, completion_status) = @@ -2024,10 +2040,7 @@ mod tests { ) .await; - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send one notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, event_turn_id); @@ -2066,10 +2079,7 @@ mod tests { ) .await; - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send one notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, event_turn_id); @@ -2108,10 +2118,7 @@ mod tests { ) .await; - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send one notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, event_turn_id); @@ -2160,10 +2167,7 @@ mod tests { ) .await; - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send one notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnPlanUpdated(n)) => { assert_eq!(n.thread_id, conversation_id.to_string()); @@ -2233,10 +2237,7 @@ mod tests { ) .await; - let first = rx - .recv() - .await - .ok_or_else(|| anyhow!("expected usage notification"))?; + let first = recv_broadcast_message(&mut rx).await?; match first { OutgoingMessage::AppServerNotification( ServerNotification::ThreadTokenUsageUpdated(payload), @@ -2252,10 +2253,7 @@ mod tests { other => bail!("unexpected notification: {other:?}"), } - let second = rx - .recv() - .await - .ok_or_else(|| anyhow!("expected rate limit notification"))?; + let second = recv_broadcast_message(&mut rx).await?; match second { OutgoingMessage::AppServerNotification( ServerNotification::AccountRateLimitsUpdated(payload), @@ -2394,10 +2392,7 @@ mod tests { .await; // Verify: A turn 1 - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send first notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, a_turn1); @@ -2415,10 +2410,7 @@ mod tests { } // Verify: B turn 1 - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send second notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, b_turn1); @@ -2436,10 +2428,7 @@ mod tests { } // Verify: A turn 2 - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send third notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnCompleted(n)) => { assert_eq!(n.turn.id, a_turn2); @@ -2605,10 +2594,7 @@ mod tests { ) .await; - let msg = rx - .recv() - .await - .ok_or_else(|| anyhow!("should send one notification"))?; + let msg = recv_broadcast_message(&mut rx).await?; match msg { OutgoingMessage::AppServerNotification(ServerNotification::TurnDiffUpdated( notification, diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index cd1c8416c43..f3ed153b524 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -3,6 +3,8 @@ use crate::error_code::INTERNAL_ERROR_CODE; use crate::error_code::INVALID_REQUEST_ERROR_CODE; use crate::fuzzy_file_search::run_fuzzy_file_search; use crate::models::supported_models; +use crate::outgoing_message::ConnectionId; +use crate::outgoing_message::ConnectionRequestId; use crate::outgoing_message::OutgoingMessageSender; use crate::outgoing_message::OutgoingNotification; use chrono::DateTime; @@ -83,7 +85,6 @@ use codex_app_server_protocol::NewConversationParams; use codex_app_server_protocol::NewConversationResponse; use codex_app_server_protocol::RemoveConversationListenerParams; use codex_app_server_protocol::RemoveConversationSubscriptionResponse; -use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ResumeConversationParams; use codex_app_server_protocol::ResumeConversationResponse; use codex_app_server_protocol::ReviewDelivery as ApiReviewDelivery; @@ -252,10 +253,10 @@ use uuid::Uuid; use crate::filters::compute_source_filters; use crate::filters::source_kind_matches; -type PendingInterruptQueue = Vec<(RequestId, ApiVersion)>; +type PendingInterruptQueue = Vec<(ConnectionRequestId, ApiVersion)>; pub(crate) type PendingInterrupts = Arc>>; -pub(crate) type PendingRollbacks = Arc>>; +pub(crate) type PendingRollbacks = Arc>>; /// Per-conversation accumulation of the latest states e.g. error message while a turn runs. #[derive(Default, Clone)] @@ -486,103 +487,137 @@ impl CodexMessageProcessor { Ok((review_request, hint)) } - pub async fn process_request(&mut self, request: ClientRequest) { + pub async fn process_request(&mut self, connection_id: ConnectionId, request: ClientRequest) { + let to_connection_request_id = |request_id| ConnectionRequestId { + connection_id, + request_id, + }; + match request { ClientRequest::Initialize { .. } => { panic!("Initialize should be handled in MessageProcessor"); } // === v2 Thread/Turn APIs === ClientRequest::ThreadStart { request_id, params } => { - self.thread_start(request_id, params).await; + self.thread_start(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadResume { request_id, params } => { - self.thread_resume(request_id, params).await; + self.thread_resume(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadFork { request_id, params } => { - self.thread_fork(request_id, params).await; + self.thread_fork(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadArchive { request_id, params } => { - self.thread_archive(request_id, params).await; + self.thread_archive(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadSetName { request_id, params } => { - self.thread_set_name(request_id, params).await; + self.thread_set_name(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadUnarchive { request_id, params } => { - self.thread_unarchive(request_id, params).await; + self.thread_unarchive(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadCompactStart { request_id, params } => { - self.thread_compact_start(request_id, params).await; + self.thread_compact_start(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadBackgroundTerminalsClean { request_id, params } => { - self.thread_background_terminals_clean(request_id, params) - .await; + self.thread_background_terminals_clean( + to_connection_request_id(request_id), + params, + ) + .await; } ClientRequest::ThreadRollback { request_id, params } => { - self.thread_rollback(request_id, params).await; + self.thread_rollback(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadList { request_id, params } => { - self.thread_list(request_id, params).await; + self.thread_list(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadLoadedList { request_id, params } => { - self.thread_loaded_list(request_id, params).await; + self.thread_loaded_list(to_connection_request_id(request_id), params) + .await; } ClientRequest::ThreadRead { request_id, params } => { - self.thread_read(request_id, params).await; + self.thread_read(to_connection_request_id(request_id), params) + .await; } ClientRequest::SkillsList { request_id, params } => { - self.skills_list(request_id, params).await; + self.skills_list(to_connection_request_id(request_id), params) + .await; } ClientRequest::SkillsRemoteRead { request_id, params } => { - self.skills_remote_read(request_id, params).await; + self.skills_remote_read(to_connection_request_id(request_id), params) + .await; } ClientRequest::SkillsRemoteWrite { request_id, params } => { - self.skills_remote_write(request_id, params).await; + self.skills_remote_write(to_connection_request_id(request_id), params) + .await; } ClientRequest::AppsList { request_id, params } => { - self.apps_list(request_id, params).await; + self.apps_list(to_connection_request_id(request_id), params) + .await; } ClientRequest::SkillsConfigWrite { request_id, params } => { - self.skills_config_write(request_id, params).await; + self.skills_config_write(to_connection_request_id(request_id), params) + .await; } ClientRequest::TurnStart { request_id, params } => { - self.turn_start(request_id, params).await; + self.turn_start(to_connection_request_id(request_id), params) + .await; } ClientRequest::TurnSteer { request_id, params } => { - self.turn_steer(request_id, params).await; + self.turn_steer(to_connection_request_id(request_id), params) + .await; } ClientRequest::TurnInterrupt { request_id, params } => { - self.turn_interrupt(request_id, params).await; + self.turn_interrupt(to_connection_request_id(request_id), params) + .await; } ClientRequest::ReviewStart { request_id, params } => { - self.review_start(request_id, params).await; + self.review_start(to_connection_request_id(request_id), params) + .await; } ClientRequest::NewConversation { request_id, params } => { // Do not tokio::spawn() to process new_conversation() // asynchronously because we need to ensure the conversation is // created before processing any subsequent messages. - self.process_new_conversation(request_id, params).await; + self.process_new_conversation(to_connection_request_id(request_id), params) + .await; } ClientRequest::GetConversationSummary { request_id, params } => { - self.get_thread_summary(request_id, params).await; + self.get_thread_summary(to_connection_request_id(request_id), params) + .await; } ClientRequest::ListConversations { request_id, params } => { - self.handle_list_conversations(request_id, params).await; + self.handle_list_conversations(to_connection_request_id(request_id), params) + .await; } ClientRequest::ModelList { request_id, params } => { let outgoing = self.outgoing.clone(); let thread_manager = self.thread_manager.clone(); let config = self.config.clone(); + let request_id = to_connection_request_id(request_id); tokio::spawn(async move { Self::list_models(outgoing, thread_manager, config, request_id, params).await; }); } ClientRequest::ExperimentalFeatureList { request_id, params } => { - self.experimental_feature_list(request_id, params).await; + self.experimental_feature_list(to_connection_request_id(request_id), params) + .await; } ClientRequest::CollaborationModeList { request_id, params } => { let outgoing = self.outgoing.clone(); let thread_manager = self.thread_manager.clone(); + let request_id = to_connection_request_id(request_id); tokio::spawn(async move { Self::list_collaboration_modes(outgoing, thread_manager, request_id, params) @@ -590,109 +625,136 @@ impl CodexMessageProcessor { }); } ClientRequest::MockExperimentalMethod { request_id, params } => { - self.mock_experimental_method(request_id, params).await; + self.mock_experimental_method(to_connection_request_id(request_id), params) + .await; } ClientRequest::McpServerOauthLogin { request_id, params } => { - self.mcp_server_oauth_login(request_id, params).await; + self.mcp_server_oauth_login(to_connection_request_id(request_id), params) + .await; } ClientRequest::McpServerRefresh { request_id, params } => { - self.mcp_server_refresh(request_id, params).await; + self.mcp_server_refresh(to_connection_request_id(request_id), params) + .await; } ClientRequest::McpServerStatusList { request_id, params } => { - self.list_mcp_server_status(request_id, params).await; + self.list_mcp_server_status(to_connection_request_id(request_id), params) + .await; } ClientRequest::LoginAccount { request_id, params } => { - self.login_v2(request_id, params).await; + self.login_v2(to_connection_request_id(request_id), params) + .await; } ClientRequest::LogoutAccount { request_id, params: _, } => { - self.logout_v2(request_id).await; + self.logout_v2(to_connection_request_id(request_id)).await; } ClientRequest::CancelLoginAccount { request_id, params } => { - self.cancel_login_v2(request_id, params).await; + self.cancel_login_v2(to_connection_request_id(request_id), params) + .await; } ClientRequest::GetAccount { request_id, params } => { - self.get_account(request_id, params).await; + self.get_account(to_connection_request_id(request_id), params) + .await; } ClientRequest::ResumeConversation { request_id, params } => { - self.handle_resume_conversation(request_id, params).await; + self.handle_resume_conversation(to_connection_request_id(request_id), params) + .await; } ClientRequest::ForkConversation { request_id, params } => { - self.handle_fork_conversation(request_id, params).await; + self.handle_fork_conversation(to_connection_request_id(request_id), params) + .await; } ClientRequest::ArchiveConversation { request_id, params } => { - self.archive_conversation(request_id, params).await; + self.archive_conversation(to_connection_request_id(request_id), params) + .await; } ClientRequest::SendUserMessage { request_id, params } => { - self.send_user_message(request_id, params).await; + self.send_user_message(to_connection_request_id(request_id), params) + .await; } ClientRequest::SendUserTurn { request_id, params } => { - self.send_user_turn(request_id, params).await; + self.send_user_turn(to_connection_request_id(request_id), params) + .await; } ClientRequest::InterruptConversation { request_id, params } => { - self.interrupt_conversation(request_id, params).await; + self.interrupt_conversation(to_connection_request_id(request_id), params) + .await; } ClientRequest::AddConversationListener { request_id, params } => { - self.add_conversation_listener(request_id, params).await; + self.add_conversation_listener(to_connection_request_id(request_id), params) + .await; } ClientRequest::RemoveConversationListener { request_id, params } => { - self.remove_thread_listener(request_id, params).await; + self.remove_thread_listener(to_connection_request_id(request_id), params) + .await; } ClientRequest::GitDiffToRemote { request_id, params } => { - self.git_diff_to_origin(request_id, params.cwd).await; + self.git_diff_to_origin(to_connection_request_id(request_id), params.cwd) + .await; } ClientRequest::LoginApiKey { request_id, params } => { - self.login_api_key_v1(request_id, params).await; + self.login_api_key_v1(to_connection_request_id(request_id), params) + .await; } ClientRequest::LoginChatGpt { request_id, params: _, } => { - self.login_chatgpt_v1(request_id).await; + self.login_chatgpt_v1(to_connection_request_id(request_id)) + .await; } ClientRequest::CancelLoginChatGpt { request_id, params } => { - self.cancel_login_chatgpt(request_id, params.login_id).await; + self.cancel_login_chatgpt(to_connection_request_id(request_id), params.login_id) + .await; } ClientRequest::LogoutChatGpt { request_id, params: _, } => { - self.logout_v1(request_id).await; + self.logout_v1(to_connection_request_id(request_id)).await; } ClientRequest::GetAuthStatus { request_id, params } => { - self.get_auth_status(request_id, params).await; + self.get_auth_status(to_connection_request_id(request_id), params) + .await; } ClientRequest::GetUserSavedConfig { request_id, params: _, } => { - self.get_user_saved_config(request_id).await; + self.get_user_saved_config(to_connection_request_id(request_id)) + .await; } ClientRequest::SetDefaultModel { request_id, params } => { - self.set_default_model(request_id, params).await; + self.set_default_model(to_connection_request_id(request_id), params) + .await; } ClientRequest::GetUserAgent { request_id, params: _, } => { - self.get_user_agent(request_id).await; + self.get_user_agent(to_connection_request_id(request_id)) + .await; } ClientRequest::UserInfo { request_id, params: _, } => { - self.get_user_info(request_id).await; + self.get_user_info(to_connection_request_id(request_id)) + .await; } ClientRequest::FuzzyFileSearch { request_id, params } => { - self.fuzzy_file_search(request_id, params).await; + self.fuzzy_file_search(to_connection_request_id(request_id), params) + .await; } ClientRequest::OneOffCommandExec { request_id, params } => { - self.exec_one_off_command(request_id, params).await; + self.exec_one_off_command(to_connection_request_id(request_id), params) + .await; } ClientRequest::ExecOneOffCommand { request_id, params } => { - self.exec_one_off_command(request_id, params.into()).await; + self.exec_one_off_command(to_connection_request_id(request_id), params.into()) + .await; } ClientRequest::ConfigRead { .. } | ClientRequest::ConfigValueWrite { .. } @@ -706,15 +768,17 @@ impl CodexMessageProcessor { request_id, params: _, } => { - self.get_account_rate_limits(request_id).await; + self.get_account_rate_limits(to_connection_request_id(request_id)) + .await; } ClientRequest::FeedbackUpload { request_id, params } => { - self.upload_feedback(request_id, params).await; + self.upload_feedback(to_connection_request_id(request_id), params) + .await; } } } - async fn login_v2(&mut self, request_id: RequestId, params: LoginAccountParams) { + async fn login_v2(&mut self, request_id: ConnectionRequestId, params: LoginAccountParams) { match params { LoginAccountParams::ApiKey { api_key } => { self.login_api_key_v2(request_id, LoginApiKeyParams { api_key }) @@ -792,7 +856,11 @@ impl CodexMessageProcessor { } } - async fn login_api_key_v1(&mut self, request_id: RequestId, params: LoginApiKeyParams) { + async fn login_api_key_v1( + &mut self, + request_id: ConnectionRequestId, + params: LoginApiKeyParams, + ) { match self.login_api_key_common(¶ms).await { Ok(()) => { self.outgoing @@ -816,7 +884,11 @@ impl CodexMessageProcessor { } } - async fn login_api_key_v2(&mut self, request_id: RequestId, params: LoginApiKeyParams) { + async fn login_api_key_v2( + &mut self, + request_id: ConnectionRequestId, + params: LoginApiKeyParams, + ) { match self.login_api_key_common(¶ms).await { Ok(()) => { let response = codex_app_server_protocol::LoginAccountResponse::ApiKey {}; @@ -880,7 +952,7 @@ impl CodexMessageProcessor { } // Deprecated in favor of login_chatgpt_v2. - async fn login_chatgpt_v1(&mut self, request_id: RequestId) { + async fn login_chatgpt_v1(&mut self, request_id: ConnectionRequestId) { match self.login_chatgpt_common().await { Ok(opts) => match run_login_server(opts) { Ok(server) => { @@ -988,7 +1060,7 @@ impl CodexMessageProcessor { } } - async fn login_chatgpt_v2(&mut self, request_id: RequestId) { + async fn login_chatgpt_v2(&mut self, request_id: ConnectionRequestId) { match self.login_chatgpt_common().await { Ok(opts) => match run_login_server(opts) { Ok(server) => { @@ -1114,7 +1186,7 @@ impl CodexMessageProcessor { } } - async fn cancel_login_chatgpt(&mut self, request_id: RequestId, login_id: Uuid) { + async fn cancel_login_chatgpt(&mut self, request_id: ConnectionRequestId, login_id: Uuid) { match self.cancel_login_chatgpt_common(login_id).await { Ok(()) => { self.outgoing @@ -1132,7 +1204,11 @@ impl CodexMessageProcessor { } } - async fn cancel_login_v2(&mut self, request_id: RequestId, params: CancelLoginAccountParams) { + async fn cancel_login_v2( + &mut self, + request_id: ConnectionRequestId, + params: CancelLoginAccountParams, + ) { let login_id = params.login_id; match Uuid::parse_str(&login_id) { Ok(uuid) => { @@ -1156,7 +1232,7 @@ impl CodexMessageProcessor { async fn login_chatgpt_auth_tokens( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, access_token: String, chatgpt_account_id: String, chatgpt_plan_type: Option, @@ -1272,7 +1348,7 @@ impl CodexMessageProcessor { .map(CodexAuth::api_auth_mode)) } - async fn logout_v1(&mut self, request_id: RequestId) { + async fn logout_v1(&mut self, request_id: ConnectionRequestId) { match self.logout_common().await { Ok(current_auth_method) => { self.outgoing @@ -1292,7 +1368,7 @@ impl CodexMessageProcessor { } } - async fn logout_v2(&mut self, request_id: RequestId) { + async fn logout_v2(&mut self, request_id: ConnectionRequestId) { match self.logout_common().await { Ok(current_auth_method) => { self.outgoing @@ -1321,7 +1397,7 @@ impl CodexMessageProcessor { } } - async fn get_auth_status(&self, request_id: RequestId, params: GetAuthStatusParams) { + async fn get_auth_status(&self, request_id: ConnectionRequestId, params: GetAuthStatusParams) { let include_token = params.include_token.unwrap_or(false); let do_refresh = params.refresh_token.unwrap_or(false); @@ -1370,7 +1446,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn get_account(&self, request_id: RequestId, params: GetAccountParams) { + async fn get_account(&self, request_id: ConnectionRequestId, params: GetAccountParams) { let do_refresh = params.refresh_token; self.refresh_token_if_requested(do_refresh).await; @@ -1422,13 +1498,13 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn get_user_agent(&self, request_id: RequestId) { + async fn get_user_agent(&self, request_id: ConnectionRequestId) { let user_agent = get_codex_user_agent(); let response = GetUserAgentResponse { user_agent }; self.outgoing.send_response(request_id, response).await; } - async fn get_account_rate_limits(&self, request_id: RequestId) { + async fn get_account_rate_limits(&self, request_id: ConnectionRequestId) { match self.fetch_account_rate_limits().await { Ok((rate_limits, rate_limits_by_limit_id)) => { let response = GetAccountRateLimitsResponse { @@ -1517,7 +1593,7 @@ impl CodexMessageProcessor { Ok((primary, rate_limits_by_limit_id)) } - async fn get_user_saved_config(&self, request_id: RequestId) { + async fn get_user_saved_config(&self, request_id: ConnectionRequestId) { let service = ConfigService::new_with_defaults(self.config.codex_home.clone()); let user_saved_config: UserSavedConfig = match service.load_user_saved_config().await { Ok(config) => config, @@ -1538,7 +1614,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn get_user_info(&self, request_id: RequestId) { + async fn get_user_info(&self, request_id: ConnectionRequestId) { // Read alleged user email from cached auth (best-effort; not verified). let alleged_user_email = self .auth_manager @@ -1549,7 +1625,11 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn set_default_model(&self, request_id: RequestId, params: SetDefaultModelParams) { + async fn set_default_model( + &self, + request_id: ConnectionRequestId, + params: SetDefaultModelParams, + ) { let SetDefaultModelParams { model, reasoning_effort, @@ -1576,16 +1656,22 @@ impl CodexMessageProcessor { } } - async fn exec_one_off_command(&self, request_id: RequestId, params: CommandExecParams) { + async fn exec_one_off_command( + &self, + request_id: ConnectionRequestId, + params: CommandExecParams, + ) { tracing::debug!("ExecOneOffCommand params: {params:?}"); + let request = request_id.clone(); + if params.command.is_empty() { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, message: "command must not be empty".to_string(), data: None, }; - self.outgoing.send_error(request_id, error).await; + self.outgoing.send_error(request, error).await; return; } @@ -1603,7 +1689,7 @@ impl CodexMessageProcessor { message: format!("failed to start managed network proxy: {err}"), data: None, }; - self.outgoing.send_error(request_id, error).await; + self.outgoing.send_error(request, error).await; return; } }, @@ -1634,7 +1720,7 @@ impl CodexMessageProcessor { message: format!("invalid sandbox policy: {err}"), data: None, }; - self.outgoing.send_error(request_id, error).await; + self.outgoing.send_error(request, error).await; return; } }, @@ -1643,7 +1729,7 @@ impl CodexMessageProcessor { let codex_linux_sandbox_exe = self.config.codex_linux_sandbox_exe.clone(); let outgoing = self.outgoing.clone(); - let req_id = request_id; + let request_for_task = request; let sandbox_cwd = self.config.cwd.clone(); let started_network_proxy_for_task = started_network_proxy; let use_linux_sandbox_bwrap = self.config.features.enabled(Feature::UseLinuxSandboxBwrap); @@ -1666,7 +1752,7 @@ impl CodexMessageProcessor { stdout: output.stdout.text, stderr: output.stderr.text, }; - outgoing.send_response(req_id, response).await; + outgoing.send_response(request_for_task, response).await; } Err(err) => { let error = JSONRPCErrorError { @@ -1674,7 +1760,7 @@ impl CodexMessageProcessor { message: format!("exec failed: {err}"), data: None, }; - outgoing.send_error(req_id, error).await; + outgoing.send_error(request_for_task, error).await; } } }); @@ -1682,7 +1768,7 @@ impl CodexMessageProcessor { async fn process_new_conversation( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, params: NewConversationParams, ) { let NewConversationParams { @@ -1783,7 +1869,7 @@ impl CodexMessageProcessor { } } - async fn thread_start(&mut self, request_id: RequestId, params: ThreadStartParams) { + async fn thread_start(&mut self, request_id: ConnectionRequestId, params: ThreadStartParams) { let ThreadStartParams { model, model_provider, @@ -1948,7 +2034,11 @@ impl CodexMessageProcessor { } } - async fn thread_archive(&mut self, request_id: RequestId, params: ThreadArchiveParams) { + async fn thread_archive( + &mut self, + request_id: ConnectionRequestId, + params: ThreadArchiveParams, + ) { // TODO(jif) mostly rewrite this using sqlite after phase 1 let thread_id = match ThreadId::from_string(¶ms.thread_id) { Ok(id) => id, @@ -1998,7 +2088,7 @@ impl CodexMessageProcessor { } } - async fn thread_set_name(&self, request_id: RequestId, params: ThreadSetNameParams) { + async fn thread_set_name(&self, request_id: ConnectionRequestId, params: ThreadSetNameParams) { let ThreadSetNameParams { thread_id, name } = params; let Some(name) = codex_core::util::normalize_thread_name(&name) else { self.send_invalid_request_error( @@ -2028,7 +2118,11 @@ impl CodexMessageProcessor { .await; } - async fn thread_unarchive(&mut self, request_id: RequestId, params: ThreadUnarchiveParams) { + async fn thread_unarchive( + &mut self, + request_id: ConnectionRequestId, + params: ThreadUnarchiveParams, + ) { // TODO(jif) mostly rewrite this using sqlite after phase 1 let thread_id = match ThreadId::from_string(¶ms.thread_id) { Ok(id) => id, @@ -2201,7 +2295,11 @@ impl CodexMessageProcessor { } } - async fn thread_rollback(&mut self, request_id: RequestId, params: ThreadRollbackParams) { + async fn thread_rollback( + &mut self, + request_id: ConnectionRequestId, + params: ThreadRollbackParams, + ) { let ThreadRollbackParams { thread_id, num_turns, @@ -2221,18 +2319,20 @@ impl CodexMessageProcessor { } }; + let request = request_id.clone(); + { let mut map = self.pending_rollbacks.lock().await; if map.contains_key(&thread_id) { self.send_invalid_request_error( - request_id, + request.clone(), "rollback already in progress for this thread".to_string(), ) .await; return; } - map.insert(thread_id, request_id.clone()); + map.insert(thread_id, request.clone()); } if let Err(err) = thread.submit(Op::ThreadRollback { num_turns }).await { @@ -2241,12 +2341,16 @@ impl CodexMessageProcessor { let mut map = self.pending_rollbacks.lock().await; map.remove(&thread_id); - self.send_internal_error(request_id, format!("failed to start rollback: {err}")) + self.send_internal_error(request, format!("failed to start rollback: {err}")) .await; } } - async fn thread_compact_start(&self, request_id: RequestId, params: ThreadCompactStartParams) { + async fn thread_compact_start( + &self, + request_id: ConnectionRequestId, + params: ThreadCompactStartParams, + ) { let ThreadCompactStartParams { thread_id } = params; let (_, thread) = match self.load_thread(&thread_id).await { @@ -2272,7 +2376,7 @@ impl CodexMessageProcessor { async fn thread_background_terminals_clean( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ThreadBackgroundTerminalsCleanParams, ) { let ThreadBackgroundTerminalsCleanParams { thread_id } = params; @@ -2301,7 +2405,7 @@ impl CodexMessageProcessor { } } - async fn thread_list(&self, request_id: RequestId, params: ThreadListParams) { + async fn thread_list(&self, request_id: ConnectionRequestId, params: ThreadListParams) { let ThreadListParams { cursor, limit, @@ -2342,7 +2446,11 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn thread_loaded_list(&self, request_id: RequestId, params: ThreadLoadedListParams) { + async fn thread_loaded_list( + &self, + request_id: ConnectionRequestId, + params: ThreadLoadedListParams, + ) { let ThreadLoadedListParams { cursor, limit } = params; let mut data = self .thread_manager @@ -2397,7 +2505,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn thread_read(&mut self, request_id: RequestId, params: ThreadReadParams) { + async fn thread_read(&mut self, request_id: ConnectionRequestId, params: ThreadReadParams) { let ThreadReadParams { thread_id, include_turns, @@ -2548,7 +2656,7 @@ impl CodexMessageProcessor { } } - async fn thread_resume(&mut self, request_id: RequestId, params: ThreadResumeParams) { + async fn thread_resume(&mut self, request_id: ConnectionRequestId, params: ThreadResumeParams) { let ThreadResumeParams { thread_id, history, @@ -2765,7 +2873,7 @@ impl CodexMessageProcessor { } } - async fn thread_fork(&mut self, request_id: RequestId, params: ThreadForkParams) { + async fn thread_fork(&mut self, request_id: ConnectionRequestId, params: ThreadForkParams) { let ThreadForkParams { thread_id, path, @@ -2979,7 +3087,7 @@ impl CodexMessageProcessor { async fn get_thread_summary( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: GetConversationSummaryParams, ) { if let GetConversationSummaryParams::ThreadId { conversation_id } = ¶ms @@ -3045,7 +3153,7 @@ impl CodexMessageProcessor { async fn handle_list_conversations( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ListConversationsParams, ) { let ListConversationsParams { @@ -3209,7 +3317,7 @@ impl CodexMessageProcessor { outgoing: Arc, thread_manager: Arc, config: Arc, - request_id: RequestId, + request_id: ConnectionRequestId, params: ModelListParams, ) { let ModelListParams { limit, cursor } = params; @@ -3272,7 +3380,7 @@ impl CodexMessageProcessor { async fn list_collaboration_modes( outgoing: Arc, thread_manager: Arc, - request_id: RequestId, + request_id: ConnectionRequestId, params: CollaborationModeListParams, ) { let CollaborationModeListParams {} = params; @@ -3283,7 +3391,7 @@ impl CodexMessageProcessor { async fn experimental_feature_list( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ExperimentalFeatureListParams, ) { let ExperimentalFeatureListParams { cursor, limit } = params; @@ -3393,7 +3501,7 @@ impl CodexMessageProcessor { async fn mock_experimental_method( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: MockExperimentalMethodParams, ) { let MockExperimentalMethodParams { value } = params; @@ -3401,7 +3509,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn mcp_server_refresh(&self, request_id: RequestId, _params: Option<()>) { + async fn mcp_server_refresh(&self, request_id: ConnectionRequestId, _params: Option<()>) { let config = match self.load_latest_config().await { Ok(config) => config, Err(error) => { @@ -3454,7 +3562,7 @@ impl CodexMessageProcessor { async fn mcp_server_oauth_login( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: McpServerOauthLoginParams, ) { let config = match self.load_latest_config().await { @@ -3551,26 +3659,28 @@ impl CodexMessageProcessor { async fn list_mcp_server_status( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ListMcpServerStatusParams, ) { + let request = request_id.clone(); + let outgoing = Arc::clone(&self.outgoing); let config = match self.load_latest_config().await { Ok(config) => config, Err(error) => { - self.outgoing.send_error(request_id, error).await; + self.outgoing.send_error(request, error).await; return; } }; tokio::spawn(async move { - Self::list_mcp_server_status_task(outgoing, request_id, params, config).await; + Self::list_mcp_server_status_task(outgoing, request, params, config).await; }); } async fn list_mcp_server_status_task( outgoing: Arc, - request_id: RequestId, + request_id: ConnectionRequestId, params: ListMcpServerStatusParams, config: Config, ) { @@ -3653,7 +3763,7 @@ impl CodexMessageProcessor { async fn handle_resume_conversation( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ResumeConversationParams, ) { let ResumeConversationParams { @@ -3861,7 +3971,7 @@ impl CodexMessageProcessor { async fn handle_fork_conversation( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ForkConversationParams, ) { let ForkConversationParams { @@ -4057,7 +4167,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn send_invalid_request_error(&self, request_id: RequestId, message: String) { + async fn send_invalid_request_error(&self, request_id: ConnectionRequestId, message: String) { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, message, @@ -4066,7 +4176,7 @@ impl CodexMessageProcessor { self.outgoing.send_error(request_id, error).await; } - async fn send_internal_error(&self, request_id: RequestId, message: String) { + async fn send_internal_error(&self, request_id: ConnectionRequestId, message: String) { let error = JSONRPCErrorError { code: INTERNAL_ERROR_CODE, message, @@ -4077,7 +4187,7 @@ impl CodexMessageProcessor { async fn archive_conversation( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ArchiveConversationParams, ) { let ArchiveConversationParams { @@ -4222,7 +4332,11 @@ impl CodexMessageProcessor { }) } - async fn send_user_message(&self, request_id: RequestId, params: SendUserMessageParams) { + async fn send_user_message( + &self, + request_id: ConnectionRequestId, + params: SendUserMessageParams, + ) { let SendUserMessageParams { conversation_id, items, @@ -4266,7 +4380,7 @@ impl CodexMessageProcessor { .await; } - async fn send_user_turn(&self, request_id: RequestId, params: SendUserTurnParams) { + async fn send_user_turn(&self, request_id: ConnectionRequestId, params: SendUserTurnParams) { let SendUserTurnParams { conversation_id, items, @@ -4324,7 +4438,7 @@ impl CodexMessageProcessor { .await; } - async fn apps_list(&self, request_id: RequestId, params: AppsListParams) { + async fn apps_list(&self, request_id: ConnectionRequestId, params: AppsListParams) { let mut config = match self.load_latest_config().await { Ok(config) => config, Err(error) => { @@ -4371,7 +4485,7 @@ impl CodexMessageProcessor { async fn apps_list_task( outgoing: Arc, - request_id: RequestId, + request_id: ConnectionRequestId, params: AppsListParams, config: Config, ) { @@ -4542,7 +4656,7 @@ impl CodexMessageProcessor { .await; } - async fn skills_list(&self, request_id: RequestId, params: SkillsListParams) { + async fn skills_list(&self, request_id: ConnectionRequestId, params: SkillsListParams) { let SkillsListParams { cwds, force_reload, @@ -4608,7 +4722,11 @@ impl CodexMessageProcessor { .await; } - async fn skills_remote_read(&self, request_id: RequestId, _params: SkillsRemoteReadParams) { + async fn skills_remote_read( + &self, + request_id: ConnectionRequestId, + _params: SkillsRemoteReadParams, + ) { match list_remote_skills(&self.config).await { Ok(skills) => { let data = skills @@ -4633,7 +4751,11 @@ impl CodexMessageProcessor { } } - async fn skills_remote_write(&self, request_id: RequestId, params: SkillsRemoteWriteParams) { + async fn skills_remote_write( + &self, + request_id: ConnectionRequestId, + params: SkillsRemoteWriteParams, + ) { let SkillsRemoteWriteParams { hazelnut_id, is_preload, @@ -4663,7 +4785,11 @@ impl CodexMessageProcessor { } } - async fn skills_config_write(&self, request_id: RequestId, params: SkillsConfigWriteParams) { + async fn skills_config_write( + &self, + request_id: ConnectionRequestId, + params: SkillsConfigWriteParams, + ) { let SkillsConfigWriteParams { path, enabled } = params; let edits = vec![ConfigEdit::SetSkillConfig { path, enabled }]; let result = ConfigEditsBuilder::new(&self.config.codex_home) @@ -4696,7 +4822,7 @@ impl CodexMessageProcessor { async fn interrupt_conversation( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, params: InterruptConversationParams, ) { let InterruptConversationParams { conversation_id } = params; @@ -4710,19 +4836,21 @@ impl CodexMessageProcessor { return; }; + let request = request_id.clone(); + // Record the pending interrupt so we can reply when TurnAborted arrives. { let mut map = self.pending_interrupts.lock().await; map.entry(conversation_id) .or_default() - .push((request_id, ApiVersion::V1)); + .push((request, ApiVersion::V1)); } // Submit the interrupt; we'll respond upon TurnAborted. let _ = conversation.submit(Op::Interrupt).await; } - async fn turn_start(&self, request_id: RequestId, params: TurnStartParams) { + async fn turn_start(&self, request_id: ConnectionRequestId, params: TurnStartParams) { let (_, thread) = match self.load_thread(¶ms.thread_id).await { Ok(v) => v, Err(error) => { @@ -4808,7 +4936,7 @@ impl CodexMessageProcessor { } } - async fn turn_steer(&self, request_id: RequestId, params: TurnSteerParams) { + async fn turn_steer(&self, request_id: ConnectionRequestId, params: TurnSteerParams) { let (_, thread) = match self.load_thread(¶ms.thread_id).await { Ok(v) => v, Err(error) => { @@ -4889,7 +5017,7 @@ impl CodexMessageProcessor { async fn emit_review_started( &self, - request_id: &RequestId, + request_id: &ConnectionRequestId, turn: Turn, parent_thread_id: String, review_thread_id: String, @@ -4913,7 +5041,7 @@ impl CodexMessageProcessor { async fn start_inline_review( &self, - request_id: &RequestId, + request_id: &ConnectionRequestId, parent_thread: Arc, review_request: ReviewRequest, display_text: &str, @@ -4943,7 +5071,7 @@ impl CodexMessageProcessor { async fn start_detached_review( &mut self, - request_id: &RequestId, + request_id: &ConnectionRequestId, parent_thread_id: ThreadId, review_request: ReviewRequest, display_text: &str, @@ -5035,7 +5163,7 @@ impl CodexMessageProcessor { Ok(()) } - async fn review_start(&mut self, request_id: RequestId, params: ReviewStartParams) { + async fn review_start(&mut self, request_id: ConnectionRequestId, params: ReviewStartParams) { let ReviewStartParams { thread_id, target, @@ -5089,7 +5217,11 @@ impl CodexMessageProcessor { } } - async fn turn_interrupt(&mut self, request_id: RequestId, params: TurnInterruptParams) { + async fn turn_interrupt( + &mut self, + request_id: ConnectionRequestId, + params: TurnInterruptParams, + ) { let TurnInterruptParams { thread_id, .. } = params; let (thread_uuid, thread) = match self.load_thread(&thread_id).await { @@ -5100,12 +5232,14 @@ impl CodexMessageProcessor { } }; + let request = request_id.clone(); + // Record the pending interrupt so we can reply when TurnAborted arrives. { let mut map = self.pending_interrupts.lock().await; map.entry(thread_uuid) .or_default() - .push((request_id, ApiVersion::V2)); + .push((request, ApiVersion::V2)); } // Submit the interrupt; we'll respond upon TurnAborted. @@ -5114,7 +5248,7 @@ impl CodexMessageProcessor { async fn add_conversation_listener( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, params: AddConversationListenerParams, ) { let AddConversationListenerParams { @@ -5137,7 +5271,7 @@ impl CodexMessageProcessor { async fn remove_thread_listener( &mut self, - request_id: RequestId, + request_id: ConnectionRequestId, params: RemoveConversationListenerParams, ) { let RemoveConversationListenerParams { subscription_id } = params; @@ -5267,7 +5401,7 @@ impl CodexMessageProcessor { Ok(subscription_id) } - async fn git_diff_to_origin(&self, request_id: RequestId, cwd: PathBuf) { + async fn git_diff_to_origin(&self, request_id: ConnectionRequestId, cwd: PathBuf) { let diff = git_diff_to_remote(&cwd).await; match diff { Some(value) => { @@ -5288,7 +5422,11 @@ impl CodexMessageProcessor { } } - async fn fuzzy_file_search(&mut self, request_id: RequestId, params: FuzzyFileSearchParams) { + async fn fuzzy_file_search( + &mut self, + request_id: ConnectionRequestId, + params: FuzzyFileSearchParams, + ) { let FuzzyFileSearchParams { query, roots, @@ -5328,7 +5466,7 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } - async fn upload_feedback(&self, request_id: RequestId, params: FeedbackUploadParams) { + async fn upload_feedback(&self, request_id: ConnectionRequestId, params: FeedbackUploadParams) { if !self.config.feedback_enabled { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, diff --git a/codex-rs/app-server/src/error_code.rs b/codex-rs/app-server/src/error_code.rs index 1ffd889d404..ca93b2f2d33 100644 --- a/codex-rs/app-server/src/error_code.rs +++ b/codex-rs/app-server/src/error_code.rs @@ -1,2 +1,3 @@ pub(crate) const INVALID_REQUEST_ERROR_CODE: i64 = -32600; pub(crate) const INTERNAL_ERROR_CODE: i64 = -32603; +pub(crate) const OVERLOADED_ERROR_CODE: i64 = -32001; diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 5b412d257e8..2a31b2053e7 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -8,14 +8,29 @@ use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::ConfigLayerStackOrdering; use codex_core::config_loader::LoaderOverrides; use codex_utils_cli::CliConfigOverrides; +use std::collections::HashMap; +use std::collections::HashSet; +use std::collections::VecDeque; use std::io::ErrorKind; use std::io::Result as IoResult; use std::path::PathBuf; +use std::sync::Arc; +use std::sync::RwLock; +use std::sync::atomic::AtomicBool; use crate::message_processor::MessageProcessor; use crate::message_processor::MessageProcessorArgs; -use crate::outgoing_message::OutgoingMessage; +use crate::outgoing_message::ConnectionId; +use crate::outgoing_message::OutgoingEnvelope; use crate::outgoing_message::OutgoingMessageSender; +use crate::transport::CHANNEL_CAPACITY; +use crate::transport::ConnectionState; +use crate::transport::OutboundConnectionState; +use crate::transport::TransportEvent; +use crate::transport::has_initialized_connections; +use crate::transport::route_outgoing_envelope; +use crate::transport::start_stdio_connection; +use crate::transport::start_websocket_acceptor; use codex_app_server_protocol::ConfigLayerSource; use codex_app_server_protocol::ConfigWarningNotification; use codex_app_server_protocol::JSONRPCMessage; @@ -26,13 +41,9 @@ use codex_core::check_execpolicy_for_warnings; use codex_core::config_loader::ConfigLoadError; use codex_core::config_loader::TextRange as CoreTextRange; use codex_feedback::CodexFeedback; -use tokio::io::AsyncBufReadExt; -use tokio::io::AsyncWriteExt; -use tokio::io::BufReader; -use tokio::io::{self}; use tokio::sync::mpsc; +use tokio::task::JoinHandle; use toml::Value as TomlValue; -use tracing::debug; use tracing::error; use tracing::info; use tracing::warn; @@ -51,11 +62,30 @@ mod fuzzy_file_search; mod message_processor; mod models; mod outgoing_message; - -/// Size of the bounded channels used to communicate between tasks. The value -/// is a balance between throughput and memory usage – 128 messages should be -/// plenty for an interactive CLI. -const CHANNEL_CAPACITY: usize = 128; +mod transport; + +pub use crate::transport::AppServerTransport; + +/// Control-plane messages from the processor/transport side to the outbound router task. +/// +/// `run_main_with_transport` now uses two loops/tasks: +/// - processor loop: handles incoming JSON-RPC and request dispatch +/// - outbound loop: performs potentially slow writes to per-connection writers +/// +/// `OutboundControlEvent` keeps those loops coordinated without sharing mutable +/// connection state directly. In particular, the outbound loop needs to know +/// when a connection opens/closes so it can route messages correctly. +enum OutboundControlEvent { + /// Register a new writer for an opened connection. + Opened { + connection_id: ConnectionId, + writer: mpsc::Sender, + initialized: Arc, + opted_out_notification_methods: Arc>>, + }, + /// Remove state for a closed/disconnected connection. + Closed { connection_id: ConnectionId }, +} fn config_warning_from_error( summary: impl Into, @@ -173,32 +203,41 @@ pub async fn run_main( loader_overrides: LoaderOverrides, default_analytics_enabled: bool, ) -> IoResult<()> { - // Set up channels. - let (incoming_tx, mut incoming_rx) = mpsc::channel::(CHANNEL_CAPACITY); - let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); - - // Task: read from stdin, push to `incoming_tx`. - let stdin_reader_handle = tokio::spawn({ - async move { - let stdin = io::stdin(); - let reader = BufReader::new(stdin); - let mut lines = reader.lines(); - - while let Some(line) = lines.next_line().await.unwrap_or_default() { - match serde_json::from_str::(&line) { - Ok(msg) => { - if incoming_tx.send(msg).await.is_err() { - // Receiver gone – nothing left to do. - break; - } - } - Err(e) => error!("Failed to deserialize JSONRPCMessage: {e}"), - } - } + run_main_with_transport( + codex_linux_sandbox_exe, + cli_config_overrides, + loader_overrides, + default_analytics_enabled, + AppServerTransport::Stdio, + ) + .await +} - debug!("stdin reader finished (EOF)"); +pub async fn run_main_with_transport( + codex_linux_sandbox_exe: Option, + cli_config_overrides: CliConfigOverrides, + loader_overrides: LoaderOverrides, + default_analytics_enabled: bool, + transport: AppServerTransport, +) -> IoResult<()> { + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let (outbound_control_tx, mut outbound_control_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + + let mut stdio_handles = Vec::>::new(); + let mut websocket_accept_handle = None; + match transport { + AppServerTransport::Stdio => { + start_stdio_connection(transport_event_tx.clone(), &mut stdio_handles).await?; } - }); + AppServerTransport::WebSocket { bind_address } => { + websocket_accept_handle = + Some(start_websocket_acceptor(bind_address, transport_event_tx.clone()).await?); + } + } + let shutdown_when_no_connections = matches!(transport, AppServerTransport::Stdio); // Parse CLI overrides once and derive the base Config eagerly so later // components do not need to work with raw TOML values. @@ -329,15 +368,76 @@ pub async fn run_main( } } - // Task: process incoming messages. + let transport_event_tx_for_outbound = transport_event_tx.clone(); + let outbound_handle = tokio::spawn(async move { + let mut outbound_connections = HashMap::::new(); + let mut pending_closed_connections = VecDeque::::new(); + loop { + tokio::select! { + biased; + event = outbound_control_rx.recv() => { + let Some(event) = event else { + break; + }; + match event { + OutboundControlEvent::Opened { + connection_id, + writer, + initialized, + opted_out_notification_methods, + } => { + outbound_connections.insert( + connection_id, + OutboundConnectionState::new( + writer, + initialized, + opted_out_notification_methods, + ), + ); + } + OutboundControlEvent::Closed { connection_id } => { + outbound_connections.remove(&connection_id); + } + } + } + envelope = outgoing_rx.recv() => { + let Some(envelope) = envelope else { + break; + }; + let disconnected_connections = + route_outgoing_envelope(&mut outbound_connections, envelope).await; + pending_closed_connections.extend(disconnected_connections); + } + } + + while let Some(connection_id) = pending_closed_connections.front().copied() { + match transport_event_tx_for_outbound + .try_send(TransportEvent::ConnectionClosed { connection_id }) + { + Ok(()) => { + pending_closed_connections.pop_front(); + } + Err(mpsc::error::TrySendError::Full(_)) => { + break; + } + Err(mpsc::error::TrySendError::Closed(_)) => { + return; + } + } + } + } + info!("outbound router task exited (channel closed)"); + }); + let processor_handle = tokio::spawn({ - let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx); + let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx)); + let outbound_control_tx = outbound_control_tx; let cli_overrides: Vec<(String, TomlValue)> = cli_kv_overrides.clone(); let loader_overrides = loader_overrides_for_config_api; let mut processor = MessageProcessor::new(MessageProcessorArgs { outgoing: outgoing_message_sender, codex_linux_sandbox_exe, - config: std::sync::Arc::new(config), + config: Arc::new(config), cli_overrides, loader_overrides, cloud_requirements: cloud_requirements.clone(), @@ -345,25 +445,107 @@ pub async fn run_main( config_warnings, }); let mut thread_created_rx = processor.thread_created_receiver(); + let mut connections = HashMap::::new(); async move { let mut listen_for_threads = true; loop { tokio::select! { - msg = incoming_rx.recv() => { - let Some(msg) = msg else { + event = transport_event_rx.recv() => { + let Some(event) = event else { break; }; - match msg { - JSONRPCMessage::Request(r) => processor.process_request(r).await, - JSONRPCMessage::Response(r) => processor.process_response(r).await, - JSONRPCMessage::Notification(n) => processor.process_notification(n).await, - JSONRPCMessage::Error(e) => processor.process_error(e).await, + match event { + TransportEvent::ConnectionOpened { connection_id, writer } => { + let outbound_initialized = Arc::new(AtomicBool::new(false)); + let outbound_opted_out_notification_methods = + Arc::new(RwLock::new(HashSet::new())); + if outbound_control_tx + .send(OutboundControlEvent::Opened { + connection_id, + writer, + initialized: Arc::clone(&outbound_initialized), + opted_out_notification_methods: Arc::clone( + &outbound_opted_out_notification_methods, + ), + }) + .await + .is_err() + { + break; + } + connections.insert( + connection_id, + ConnectionState::new( + outbound_initialized, + outbound_opted_out_notification_methods, + ), + ); + } + TransportEvent::ConnectionClosed { connection_id } => { + if outbound_control_tx + .send(OutboundControlEvent::Closed { connection_id }) + .await + .is_err() + { + break; + } + connections.remove(&connection_id); + if shutdown_when_no_connections && connections.is_empty() { + break; + } + } + TransportEvent::IncomingMessage { connection_id, message } => { + match message { + JSONRPCMessage::Request(request) => { + let Some(connection_state) = connections.get_mut(&connection_id) else { + warn!("dropping request from unknown connection: {:?}", connection_id); + continue; + }; + let was_initialized = connection_state.session.initialized; + processor + .process_request( + connection_id, + request, + &mut connection_state.session, + &connection_state.outbound_initialized, + ) + .await; + if let Ok(mut opted_out_notification_methods) = connection_state + .outbound_opted_out_notification_methods + .write() + { + *opted_out_notification_methods = connection_state + .session + .opted_out_notification_methods + .clone(); + } else { + warn!( + "failed to update outbound opted-out notifications" + ); + } + if !was_initialized && connection_state.session.initialized { + processor.send_initialize_notifications().await; + } + } + JSONRPCMessage::Response(response) => { + processor.process_response(response).await; + } + JSONRPCMessage::Notification(notification) => { + processor.process_notification(notification).await; + } + JSONRPCMessage::Error(err) => { + processor.process_error(err).await; + } + } + } } } created = thread_created_rx.recv(), if listen_for_threads => { match created { Ok(thread_id) => { - processor.try_attach_thread_listener(thread_id).await; + if has_initialized_connections(&connections) { + processor.try_attach_thread_listener(thread_id).await; + } } Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { // TODO(jif) handle lag. @@ -384,33 +566,18 @@ pub async fn run_main( } }); - // Task: write outgoing messages to stdout. - let stdout_writer_handle = tokio::spawn(async move { - let mut stdout = io::stdout(); - while let Some(outgoing_message) = outgoing_rx.recv().await { - let Ok(value) = serde_json::to_value(outgoing_message) else { - error!("Failed to convert OutgoingMessage to JSON value"); - continue; - }; - match serde_json::to_string(&value) { - Ok(mut json) => { - json.push('\n'); - if let Err(e) = stdout.write_all(json.as_bytes()).await { - error!("Failed to write to stdout: {e}"); - break; - } - } - Err(e) => error!("Failed to serialize JSONRPCMessage: {e}"), - } - } + drop(transport_event_tx); - info!("stdout writer exited (channel closed)"); - }); + let _ = processor_handle.await; + let _ = outbound_handle.await; - // Wait for all tasks to finish. The typical exit path is the stdin reader - // hitting EOF which, once it drops `incoming_tx`, propagates shutdown to - // the processor and then to the stdout task. - let _ = tokio::join!(stdin_reader_handle, processor_handle, stdout_writer_handle); + if let Some(handle) = websocket_accept_handle { + handle.abort(); + } + + for handle in stdio_handles { + let _ = handle.await; + } Ok(()) } diff --git a/codex-rs/app-server/src/main.rs b/codex-rs/app-server/src/main.rs index c436300a2d7..5c4e5eacc7a 100644 --- a/codex-rs/app-server/src/main.rs +++ b/codex-rs/app-server/src/main.rs @@ -1,4 +1,6 @@ -use codex_app_server::run_main; +use clap::Parser; +use codex_app_server::AppServerTransport; +use codex_app_server::run_main_with_transport; use codex_arg0::arg0_dispatch_or_else; use codex_core::config_loader::LoaderOverrides; use codex_utils_cli::CliConfigOverrides; @@ -8,19 +10,34 @@ use std::path::PathBuf; // managed config file without writing to /etc. const MANAGED_CONFIG_PATH_ENV_VAR: &str = "CODEX_APP_SERVER_MANAGED_CONFIG_PATH"; +#[derive(Debug, Parser)] +struct AppServerArgs { + /// Transport endpoint URL. Supported values: `stdio://` (default), + /// `ws://IP:PORT`. + #[arg( + long = "listen", + value_name = "URL", + default_value = AppServerTransport::DEFAULT_LISTEN_URL + )] + listen: AppServerTransport, +} + fn main() -> anyhow::Result<()> { arg0_dispatch_or_else(|codex_linux_sandbox_exe| async move { + let args = AppServerArgs::parse(); let managed_config_path = managed_config_path_from_debug_env(); let loader_overrides = LoaderOverrides { managed_config_path, ..Default::default() }; + let transport = args.listen; - run_main( + run_main_with_transport( codex_linux_sandbox_exe, CliConfigOverrides::default(), loader_overrides, false, + transport, ) .await?; Ok(()) diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 368d24e7270..e7d4a0b6198 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; use std::sync::RwLock; @@ -8,6 +9,8 @@ use crate::codex_message_processor::CodexMessageProcessor; use crate::codex_message_processor::CodexMessageProcessorArgs; use crate::config_api::ConfigApi; use crate::error_code::INVALID_REQUEST_ERROR_CODE; +use crate::outgoing_message::ConnectionId; +use crate::outgoing_message::ConnectionRequestId; use crate::outgoing_message::OutgoingMessageSender; use async_trait::async_trait; use codex_app_server_protocol::ChatgptAuthTokensRefreshParams; @@ -26,7 +29,6 @@ use codex_app_server_protocol::JSONRPCErrorError; use codex_app_server_protocol::JSONRPCNotification; use codex_app_server_protocol::JSONRPCRequest; use codex_app_server_protocol::JSONRPCResponse; -use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ServerNotification; use codex_app_server_protocol::ServerRequestPayload; use codex_app_server_protocol::experimental_required_message; @@ -112,13 +114,18 @@ pub(crate) struct MessageProcessor { codex_message_processor: CodexMessageProcessor, config_api: ConfigApi, config: Arc, - initialized: bool, - experimental_api_enabled: Arc, - config_warnings: Vec, + config_warnings: Arc>, +} + +#[derive(Clone, Debug, Default)] +pub(crate) struct ConnectionSessionState { + pub(crate) initialized: bool, + experimental_api_enabled: bool, + pub(crate) opted_out_notification_methods: HashSet, } pub(crate) struct MessageProcessorArgs { - pub(crate) outgoing: OutgoingMessageSender, + pub(crate) outgoing: Arc, pub(crate) codex_linux_sandbox_exe: Option, pub(crate) config: Arc, pub(crate) cli_overrides: Vec<(String, TomlValue)>, @@ -142,8 +149,6 @@ impl MessageProcessor { feedback, config_warnings, } = args; - let outgoing = Arc::new(outgoing); - let experimental_api_enabled = Arc::new(AtomicBool::new(false)); let auth_manager = AuthManager::shared( config.codex_home.clone(), false, @@ -181,14 +186,21 @@ impl MessageProcessor { codex_message_processor, config_api, config, - initialized: false, - experimental_api_enabled, - config_warnings, + config_warnings: Arc::new(config_warnings), } } - pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) { - let request_id = request.id.clone(); + pub(crate) async fn process_request( + &mut self, + connection_id: ConnectionId, + request: JSONRPCRequest, + session: &mut ConnectionSessionState, + outbound_initialized: &AtomicBool, + ) { + let request_id = ConnectionRequestId { + connection_id, + request_id: request.id.clone(), + }; let request_json = match serde_json::to_value(&request) { Ok(request_json) => request_json, Err(err) => { @@ -219,7 +231,11 @@ impl MessageProcessor { // Handle Initialize internally so CodexMessageProcessor does not have to concern // itself with the `initialized` bool. ClientRequest::Initialize { request_id, params } => { - if self.initialized { + let request_id = ConnectionRequestId { + connection_id, + request_id, + }; + if session.initialized { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, message: "Already initialized".to_string(), @@ -228,6 +244,12 @@ impl MessageProcessor { self.outgoing.send_error(request_id, error).await; return; } else { + // TODO(maxj): Revisit capability scoping for `experimental_api_enabled`. + // Current behavior is per-connection. Reviewer feedback notes this can + // create odd cross-client behavior (for example dynamic tool calls on a + // shared thread when another connected client did not opt into + // experimental API). Proposed direction is instance-global first-write-wins + // with initialize-time mismatch rejection. let (experimental_api_enabled, opt_out_notification_methods) = match params.capabilities { Some(capabilities) => ( @@ -238,11 +260,9 @@ impl MessageProcessor { ), None => (false, Vec::new()), }; - self.experimental_api_enabled - .store(experimental_api_enabled, Ordering::Relaxed); - self.outgoing - .set_opted_out_notification_methods(opt_out_notification_methods) - .await; + session.experimental_api_enabled = experimental_api_enabled; + session.opted_out_notification_methods = + opt_out_notification_methods.into_iter().collect(); let ClientInfo { name, title: _title, @@ -258,7 +278,7 @@ impl MessageProcessor { ), data: None, }; - self.outgoing.send_error(request_id, error).await; + self.outgoing.send_error(request_id.clone(), error).await; return; } SetOriginatorError::AlreadyInitialized => { @@ -279,22 +299,13 @@ impl MessageProcessor { let response = InitializeResponse { user_agent }; self.outgoing.send_response(request_id, response).await; - self.initialized = true; - if !self.config_warnings.is_empty() { - for notification in self.config_warnings.drain(..) { - self.outgoing - .send_server_notification(ServerNotification::ConfigWarning( - notification, - )) - .await; - } - } - + session.initialized = true; + outbound_initialized.store(true, Ordering::Release); return; } } _ => { - if !self.initialized { + if !session.initialized { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, message: "Not initialized".to_string(), @@ -307,7 +318,7 @@ impl MessageProcessor { } if let Some(reason) = codex_request.experimental_reason() - && !self.experimental_api_enabled.load(Ordering::Relaxed) + && !session.experimental_api_enabled { let error = JSONRPCErrorError { code: INVALID_REQUEST_ERROR_CODE, @@ -320,22 +331,49 @@ impl MessageProcessor { match codex_request { ClientRequest::ConfigRead { request_id, params } => { - self.handle_config_read(request_id, params).await; + self.handle_config_read( + ConnectionRequestId { + connection_id, + request_id, + }, + params, + ) + .await; } ClientRequest::ConfigValueWrite { request_id, params } => { - self.handle_config_value_write(request_id, params).await; + self.handle_config_value_write( + ConnectionRequestId { + connection_id, + request_id, + }, + params, + ) + .await; } ClientRequest::ConfigBatchWrite { request_id, params } => { - self.handle_config_batch_write(request_id, params).await; + self.handle_config_batch_write( + ConnectionRequestId { + connection_id, + request_id, + }, + params, + ) + .await; } ClientRequest::ConfigRequirementsRead { request_id, params: _, } => { - self.handle_config_requirements_read(request_id).await; + self.handle_config_requirements_read(ConnectionRequestId { + connection_id, + request_id, + }) + .await; } other => { - self.codex_message_processor.process_request(other).await; + self.codex_message_processor + .process_request(connection_id, other) + .await; } } } @@ -350,10 +388,15 @@ impl MessageProcessor { self.codex_message_processor.thread_created_receiver() } - pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) { - if !self.initialized { - return; + pub(crate) async fn send_initialize_notifications(&self) { + for notification in self.config_warnings.iter().cloned() { + self.outgoing + .send_server_notification(ServerNotification::ConfigWarning(notification)) + .await; } + } + + pub(crate) async fn try_attach_thread_listener(&mut self, thread_id: ThreadId) { self.codex_message_processor .try_attach_thread_listener(thread_id) .await; @@ -372,7 +415,7 @@ impl MessageProcessor { self.outgoing.notify_client_error(err.id, err.error).await; } - async fn handle_config_read(&self, request_id: RequestId, params: ConfigReadParams) { + async fn handle_config_read(&self, request_id: ConnectionRequestId, params: ConfigReadParams) { match self.config_api.read(params).await { Ok(response) => self.outgoing.send_response(request_id, response).await, Err(error) => self.outgoing.send_error(request_id, error).await, @@ -381,7 +424,7 @@ impl MessageProcessor { async fn handle_config_value_write( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ConfigValueWriteParams, ) { match self.config_api.write_value(params).await { @@ -392,7 +435,7 @@ impl MessageProcessor { async fn handle_config_batch_write( &self, - request_id: RequestId, + request_id: ConnectionRequestId, params: ConfigBatchWriteParams, ) { match self.config_api.batch_write(params).await { @@ -401,7 +444,7 @@ impl MessageProcessor { } } - async fn handle_config_requirements_read(&self, request_id: RequestId) { + async fn handle_config_requirements_read(&self, request_id: ConnectionRequestId) { match self.config_api.config_requirements_read().await { Ok(response) => self.outgoing.send_response(request_id, response).await, Err(error) => self.outgoing.send_error(request_id, error).await, diff --git a/codex-rs/app-server/src/outgoing_message.rs b/codex-rs/app-server/src/outgoing_message.rs index c9fd9ebda84..9740393efda 100644 --- a/codex-rs/app-server/src/outgoing_message.rs +++ b/codex-rs/app-server/src/outgoing_message.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::collections::HashSet; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -20,35 +19,44 @@ use crate::error_code::INTERNAL_ERROR_CODE; #[cfg(test)] use codex_protocol::account::PlanType; +/// Stable identifier for a transport connection. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub(crate) struct ConnectionId(pub(crate) u64); + +/// Stable identifier for a client request scoped to a transport connection. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub(crate) struct ConnectionRequestId { + pub(crate) connection_id: ConnectionId, + pub(crate) request_id: RequestId, +} + +#[derive(Debug, Clone)] +pub(crate) enum OutgoingEnvelope { + ToConnection { + connection_id: ConnectionId, + message: OutgoingMessage, + }, + Broadcast { + message: OutgoingMessage, + }, +} + /// Sends messages to the client and manages request callbacks. pub(crate) struct OutgoingMessageSender { - next_request_id: AtomicI64, - sender: mpsc::Sender, + next_server_request_id: AtomicI64, + sender: mpsc::Sender, request_id_to_callback: Mutex>>, - opted_out_notification_methods: Mutex>, } impl OutgoingMessageSender { - pub(crate) fn new(sender: mpsc::Sender) -> Self { + pub(crate) fn new(sender: mpsc::Sender) -> Self { Self { - next_request_id: AtomicI64::new(0), + next_server_request_id: AtomicI64::new(0), sender, request_id_to_callback: Mutex::new(HashMap::new()), - opted_out_notification_methods: Mutex::new(HashSet::new()), } } - pub(crate) async fn set_opted_out_notification_methods(&self, methods: Vec) { - let mut opted_out = self.opted_out_notification_methods.lock().await; - opted_out.clear(); - opted_out.extend(methods); - } - - async fn should_skip_notification(&self, method: &str) -> bool { - let opted_out = self.opted_out_notification_methods.lock().await; - opted_out.contains(method) - } - pub(crate) async fn send_request( &self, request: ServerRequestPayload, @@ -61,7 +69,7 @@ impl OutgoingMessageSender { &self, request: ServerRequestPayload, ) -> (RequestId, oneshot::Receiver) { - let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed)); + let id = RequestId::Integer(self.next_server_request_id.fetch_add(1, Ordering::Relaxed)); let outgoing_message_id = id.clone(); let (tx_approve, rx_approve) = oneshot::channel(); { @@ -71,7 +79,13 @@ impl OutgoingMessageSender { let outgoing_message = OutgoingMessage::Request(request.request_with_id(outgoing_message_id.clone())); - if let Err(err) = self.sender.send(outgoing_message).await { + if let Err(err) = self + .sender + .send(OutgoingEnvelope::Broadcast { + message: outgoing_message, + }) + .await + { warn!("failed to send request {outgoing_message_id:?} to client: {err:?}"); let mut request_id_to_callback = self.request_id_to_callback.lock().await; request_id_to_callback.remove(&outgoing_message_id); @@ -121,17 +135,31 @@ impl OutgoingMessageSender { entry.is_some() } - pub(crate) async fn send_response(&self, id: RequestId, response: T) { + pub(crate) async fn send_response( + &self, + request_id: ConnectionRequestId, + response: T, + ) { match serde_json::to_value(response) { Ok(result) => { - let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result }); - if let Err(err) = self.sender.send(outgoing_message).await { + let outgoing_message = OutgoingMessage::Response(OutgoingResponse { + id: request_id.request_id, + result, + }); + if let Err(err) = self + .sender + .send(OutgoingEnvelope::ToConnection { + connection_id: request_id.connection_id, + message: outgoing_message, + }) + .await + { warn!("failed to send response to client: {err:?}"); } } Err(err) => { self.send_error( - id, + request_id, JSONRPCErrorError { code: INTERNAL_ERROR_CODE, message: format!("failed to serialize response: {err}"), @@ -144,13 +172,11 @@ impl OutgoingMessageSender { } pub(crate) async fn send_server_notification(&self, notification: ServerNotification) { - let method = notification.to_string(); - if self.should_skip_notification(&method).await { - return; - } if let Err(err) = self .sender - .send(OutgoingMessage::AppServerNotification(notification)) + .send(OutgoingEnvelope::Broadcast { + message: OutgoingMessage::AppServerNotification(notification), + }) .await { warn!("failed to send server notification to client: {err:?}"); @@ -160,21 +186,35 @@ impl OutgoingMessageSender { /// All notifications should be migrated to [`ServerNotification`] and /// [`OutgoingMessage::Notification`] should be removed. pub(crate) async fn send_notification(&self, notification: OutgoingNotification) { - if self - .should_skip_notification(notification.method.as_str()) + let outgoing_message = OutgoingMessage::Notification(notification); + if let Err(err) = self + .sender + .send(OutgoingEnvelope::Broadcast { + message: outgoing_message, + }) .await { - return; - } - let outgoing_message = OutgoingMessage::Notification(notification); - if let Err(err) = self.sender.send(outgoing_message).await { warn!("failed to send notification to client: {err:?}"); } } - pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) { - let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error }); - if let Err(err) = self.sender.send(outgoing_message).await { + pub(crate) async fn send_error( + &self, + request_id: ConnectionRequestId, + error: JSONRPCErrorError, + ) { + let outgoing_message = OutgoingMessage::Error(OutgoingError { + id: request_id.request_id, + error, + }); + if let Err(err) = self + .sender + .send(OutgoingEnvelope::ToConnection { + connection_id: request_id.connection_id, + message: outgoing_message, + }) + .await + { warn!("failed to send error to client: {err:?}"); } } @@ -214,6 +254,8 @@ pub(crate) struct OutgoingError { #[cfg(test)] mod tests { + use std::time::Duration; + use codex_app_server_protocol::AccountLoginCompletedNotification; use codex_app_server_protocol::AccountRateLimitsUpdatedNotification; use codex_app_server_protocol::AccountUpdatedNotification; @@ -224,6 +266,7 @@ mod tests { use codex_app_server_protocol::RateLimitWindow; use pretty_assertions::assert_eq; use serde_json::json; + use tokio::time::timeout; use uuid::Uuid; use super::*; @@ -364,4 +407,75 @@ mod tests { "ensure the notification serializes correctly" ); } + + #[tokio::test] + async fn send_response_routes_to_target_connection() { + let (tx, mut rx) = mpsc::channel::(4); + let outgoing = OutgoingMessageSender::new(tx); + let request_id = ConnectionRequestId { + connection_id: ConnectionId(42), + request_id: RequestId::Integer(7), + }; + + outgoing + .send_response(request_id.clone(), json!({ "ok": true })) + .await; + + let envelope = timeout(Duration::from_secs(1), rx.recv()) + .await + .expect("should receive envelope before timeout") + .expect("channel should contain one message"); + + match envelope { + OutgoingEnvelope::ToConnection { + connection_id, + message, + } => { + assert_eq!(connection_id, ConnectionId(42)); + let OutgoingMessage::Response(response) = message else { + panic!("expected response message"); + }; + assert_eq!(response.id, request_id.request_id); + assert_eq!(response.result, json!({ "ok": true })); + } + other => panic!("expected targeted response envelope, got: {other:?}"), + } + } + + #[tokio::test] + async fn send_error_routes_to_target_connection() { + let (tx, mut rx) = mpsc::channel::(4); + let outgoing = OutgoingMessageSender::new(tx); + let request_id = ConnectionRequestId { + connection_id: ConnectionId(9), + request_id: RequestId::Integer(3), + }; + let error = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: "boom".to_string(), + data: None, + }; + + outgoing.send_error(request_id.clone(), error.clone()).await; + + let envelope = timeout(Duration::from_secs(1), rx.recv()) + .await + .expect("should receive envelope before timeout") + .expect("channel should contain one message"); + + match envelope { + OutgoingEnvelope::ToConnection { + connection_id, + message, + } => { + assert_eq!(connection_id, ConnectionId(9)); + let OutgoingMessage::Error(outgoing_error) = message else { + panic!("expected error message"); + }; + assert_eq!(outgoing_error.id, RequestId::Integer(3)); + assert_eq!(outgoing_error.error, error); + } + other => panic!("expected targeted error envelope, got: {other:?}"), + } + } } diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs new file mode 100644 index 00000000000..cbfd263a555 --- /dev/null +++ b/codex-rs/app-server/src/transport.rs @@ -0,0 +1,749 @@ +use crate::error_code::OVERLOADED_ERROR_CODE; +use crate::message_processor::ConnectionSessionState; +use crate::outgoing_message::ConnectionId; +use crate::outgoing_message::OutgoingEnvelope; +use crate::outgoing_message::OutgoingError; +use crate::outgoing_message::OutgoingMessage; +use codex_app_server_protocol::JSONRPCErrorError; +use codex_app_server_protocol::JSONRPCMessage; +use futures::SinkExt; +use futures::StreamExt; +use owo_colors::OwoColorize; +use owo_colors::Stream; +use owo_colors::Style; +use std::collections::HashMap; +use std::collections::HashSet; +use std::io::ErrorKind; +use std::io::Result as IoResult; +use std::net::SocketAddr; +use std::str::FromStr; +use std::sync::Arc; +use std::sync::RwLock; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::io::{self}; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::tungstenite::Message as WebSocketMessage; +use tracing::debug; +use tracing::error; +use tracing::info; +use tracing::warn; + +/// Size of the bounded channels used to communicate between tasks. The value +/// is a balance between throughput and memory usage - 128 messages should be +/// plenty for an interactive CLI. +pub(crate) const CHANNEL_CAPACITY: usize = 128; + +fn colorize(text: &str, style: Style) -> String { + text.if_supports_color(Stream::Stderr, |value| value.style(style)) + .to_string() +} + +#[allow(clippy::print_stderr)] +fn print_websocket_startup_banner(addr: SocketAddr) { + let title = colorize("codex app-server (WebSockets)", Style::new().bold().cyan()); + let listening_label = colorize("listening on:", Style::new().dimmed()); + let listen_url = colorize(&format!("ws://{addr}"), Style::new().green()); + let note_label = colorize("note:", Style::new().dimmed()); + eprintln!("{title}"); + eprintln!(" {listening_label} {listen_url}"); + if addr.ip().is_loopback() { + eprintln!( + " {note_label} binds localhost only (use SSH port-forwarding for remote access)" + ); + } else { + eprintln!( + " {note_label} this is a raw WS server; consider running behind TLS/auth for real remote use" + ); + } +} + +#[allow(clippy::print_stderr)] +fn print_websocket_connection(peer_addr: SocketAddr) { + let connected_label = colorize("websocket client connected from", Style::new().dimmed()); + eprintln!("{connected_label} {peer_addr}"); +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum AppServerTransport { + Stdio, + WebSocket { bind_address: SocketAddr }, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum AppServerTransportParseError { + UnsupportedListenUrl(String), + InvalidWebSocketListenUrl(String), +} + +impl std::fmt::Display for AppServerTransportParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!( + f, + "unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`" + ), + AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!( + f, + "invalid websocket --listen URL `{listen_url}`; expected `ws://IP:PORT`" + ), + } + } +} + +impl std::error::Error for AppServerTransportParseError {} + +impl AppServerTransport { + pub const DEFAULT_LISTEN_URL: &'static str = "stdio://"; + + pub fn from_listen_url(listen_url: &str) -> Result { + if listen_url == Self::DEFAULT_LISTEN_URL { + return Ok(Self::Stdio); + } + + if let Some(socket_addr) = listen_url.strip_prefix("ws://") { + let bind_address = socket_addr.parse::().map_err(|_| { + AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string()) + })?; + return Ok(Self::WebSocket { bind_address }); + } + + Err(AppServerTransportParseError::UnsupportedListenUrl( + listen_url.to_string(), + )) + } +} + +impl FromStr for AppServerTransport { + type Err = AppServerTransportParseError; + + fn from_str(s: &str) -> Result { + Self::from_listen_url(s) + } +} + +#[derive(Debug)] +pub(crate) enum TransportEvent { + ConnectionOpened { + connection_id: ConnectionId, + writer: mpsc::Sender, + }, + ConnectionClosed { + connection_id: ConnectionId, + }, + IncomingMessage { + connection_id: ConnectionId, + message: JSONRPCMessage, + }, +} + +pub(crate) struct ConnectionState { + pub(crate) outbound_initialized: Arc, + pub(crate) outbound_opted_out_notification_methods: Arc>>, + pub(crate) session: ConnectionSessionState, +} + +impl ConnectionState { + pub(crate) fn new( + outbound_initialized: Arc, + outbound_opted_out_notification_methods: Arc>>, + ) -> Self { + Self { + outbound_initialized, + outbound_opted_out_notification_methods, + session: ConnectionSessionState::default(), + } + } +} + +pub(crate) struct OutboundConnectionState { + pub(crate) initialized: Arc, + pub(crate) opted_out_notification_methods: Arc>>, + pub(crate) writer: mpsc::Sender, +} + +impl OutboundConnectionState { + pub(crate) fn new( + writer: mpsc::Sender, + initialized: Arc, + opted_out_notification_methods: Arc>>, + ) -> Self { + Self { + initialized, + opted_out_notification_methods, + writer, + } + } +} + +pub(crate) async fn start_stdio_connection( + transport_event_tx: mpsc::Sender, + stdio_handles: &mut Vec>, +) -> IoResult<()> { + let connection_id = ConnectionId(0); + let (writer_tx, mut writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let writer_tx_for_reader = writer_tx.clone(); + transport_event_tx + .send(TransportEvent::ConnectionOpened { + connection_id, + writer: writer_tx, + }) + .await + .map_err(|_| std::io::Error::new(ErrorKind::BrokenPipe, "processor unavailable"))?; + + let transport_event_tx_for_reader = transport_event_tx.clone(); + stdio_handles.push(tokio::spawn(async move { + let stdin = io::stdin(); + let reader = BufReader::new(stdin); + let mut lines = reader.lines(); + + loop { + match lines.next_line().await { + Ok(Some(line)) => { + if !forward_incoming_message( + &transport_event_tx_for_reader, + &writer_tx_for_reader, + connection_id, + &line, + ) + .await + { + break; + } + } + Ok(None) => break, + Err(err) => { + error!("Failed reading stdin: {err}"); + break; + } + } + } + + let _ = transport_event_tx_for_reader + .send(TransportEvent::ConnectionClosed { connection_id }) + .await; + debug!("stdin reader finished (EOF)"); + })); + + stdio_handles.push(tokio::spawn(async move { + let mut stdout = io::stdout(); + while let Some(outgoing_message) = writer_rx.recv().await { + let Some(mut json) = serialize_outgoing_message(outgoing_message) else { + continue; + }; + json.push('\n'); + if let Err(err) = stdout.write_all(json.as_bytes()).await { + error!("Failed to write to stdout: {err}"); + break; + } + } + info!("stdout writer exited (channel closed)"); + })); + + Ok(()) +} + +pub(crate) async fn start_websocket_acceptor( + bind_address: SocketAddr, + transport_event_tx: mpsc::Sender, +) -> IoResult> { + let listener = TcpListener::bind(bind_address).await?; + let local_addr = listener.local_addr()?; + print_websocket_startup_banner(local_addr); + info!("app-server websocket listening on ws://{local_addr}"); + + let connection_counter = Arc::new(AtomicU64::new(1)); + Ok(tokio::spawn(async move { + loop { + match listener.accept().await { + Ok((stream, peer_addr)) => { + print_websocket_connection(peer_addr); + let connection_id = + ConnectionId(connection_counter.fetch_add(1, Ordering::Relaxed)); + let transport_event_tx_for_connection = transport_event_tx.clone(); + tokio::spawn(async move { + run_websocket_connection( + connection_id, + stream, + transport_event_tx_for_connection, + ) + .await; + }); + } + Err(err) => { + error!("failed to accept websocket connection: {err}"); + } + } + } + })) +} + +async fn run_websocket_connection( + connection_id: ConnectionId, + stream: TcpStream, + transport_event_tx: mpsc::Sender, +) { + let websocket_stream = match accept_async(stream).await { + Ok(stream) => stream, + Err(err) => { + warn!("failed to complete websocket handshake: {err}"); + return; + } + }; + + let (writer_tx, mut writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let writer_tx_for_reader = writer_tx.clone(); + if transport_event_tx + .send(TransportEvent::ConnectionOpened { + connection_id, + writer: writer_tx, + }) + .await + .is_err() + { + return; + } + + let (mut websocket_writer, mut websocket_reader) = websocket_stream.split(); + loop { + tokio::select! { + outgoing_message = writer_rx.recv() => { + let Some(outgoing_message) = outgoing_message else { + break; + }; + let Some(json) = serialize_outgoing_message(outgoing_message) else { + continue; + }; + if websocket_writer.send(WebSocketMessage::Text(json.into())).await.is_err() { + break; + } + } + incoming_message = websocket_reader.next() => { + match incoming_message { + Some(Ok(WebSocketMessage::Text(text))) => { + if !forward_incoming_message( + &transport_event_tx, + &writer_tx_for_reader, + connection_id, + &text, + ) + .await + { + break; + } + } + Some(Ok(WebSocketMessage::Ping(payload))) => { + if websocket_writer.send(WebSocketMessage::Pong(payload)).await.is_err() { + break; + } + } + Some(Ok(WebSocketMessage::Pong(_))) => {} + Some(Ok(WebSocketMessage::Close(_))) | None => break, + Some(Ok(WebSocketMessage::Binary(_))) => { + warn!("dropping unsupported binary websocket message"); + } + Some(Ok(WebSocketMessage::Frame(_))) => {} + Some(Err(err)) => { + warn!("websocket receive error: {err}"); + break; + } + } + } + } + } + + let _ = transport_event_tx + .send(TransportEvent::ConnectionClosed { connection_id }) + .await; +} + +async fn forward_incoming_message( + transport_event_tx: &mpsc::Sender, + writer: &mpsc::Sender, + connection_id: ConnectionId, + payload: &str, +) -> bool { + match serde_json::from_str::(payload) { + Ok(message) => { + enqueue_incoming_message(transport_event_tx, writer, connection_id, message).await + } + Err(err) => { + error!("Failed to deserialize JSONRPCMessage: {err}"); + true + } + } +} + +async fn enqueue_incoming_message( + transport_event_tx: &mpsc::Sender, + writer: &mpsc::Sender, + connection_id: ConnectionId, + message: JSONRPCMessage, +) -> bool { + let event = TransportEvent::IncomingMessage { + connection_id, + message, + }; + match transport_event_tx.try_send(event) { + Ok(()) => true, + Err(mpsc::error::TrySendError::Closed(_)) => false, + Err(mpsc::error::TrySendError::Full(TransportEvent::IncomingMessage { + connection_id, + message: JSONRPCMessage::Request(request), + })) => { + let overload_error = OutgoingMessage::Error(OutgoingError { + id: request.id, + error: JSONRPCErrorError { + code: OVERLOADED_ERROR_CODE, + message: "Server overloaded; retry later.".to_string(), + data: None, + }, + }); + match writer.try_send(overload_error) { + Ok(()) => true, + Err(mpsc::error::TrySendError::Closed(_)) => false, + Err(mpsc::error::TrySendError::Full(_overload_error)) => { + warn!( + "dropping overload response for connection {:?}: outbound queue is full", + connection_id + ); + true + } + } + } + Err(mpsc::error::TrySendError::Full(event)) => transport_event_tx.send(event).await.is_ok(), + } +} + +fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option { + let value = match serde_json::to_value(outgoing_message) { + Ok(value) => value, + Err(err) => { + error!("Failed to convert OutgoingMessage to JSON value: {err}"); + return None; + } + }; + match serde_json::to_string(&value) { + Ok(json) => Some(json), + Err(err) => { + error!("Failed to serialize JSONRPCMessage: {err}"); + None + } + } +} + +fn should_skip_notification_for_connection( + connection_state: &OutboundConnectionState, + message: &OutgoingMessage, +) -> bool { + let Ok(opted_out_notification_methods) = connection_state.opted_out_notification_methods.read() + else { + warn!("failed to read outbound opted-out notifications"); + return false; + }; + match message { + OutgoingMessage::AppServerNotification(notification) => { + let method = notification.to_string(); + opted_out_notification_methods.contains(method.as_str()) + } + OutgoingMessage::Notification(notification) => { + opted_out_notification_methods.contains(notification.method.as_str()) + } + _ => false, + } +} + +pub(crate) async fn route_outgoing_envelope( + connections: &mut HashMap, + envelope: OutgoingEnvelope, +) -> Vec { + let mut disconnected = Vec::new(); + match envelope { + OutgoingEnvelope::ToConnection { + connection_id, + message, + } => { + let Some(connection_state) = connections.get(&connection_id) else { + warn!( + "dropping message for disconnected connection: {:?}", + connection_id + ); + return disconnected; + }; + if connection_state.writer.send(message).await.is_err() { + connections.remove(&connection_id); + disconnected.push(connection_id); + } + } + OutgoingEnvelope::Broadcast { message } => { + let target_connections: Vec = connections + .iter() + .filter_map(|(connection_id, connection_state)| { + if connection_state.initialized.load(Ordering::Acquire) + && !should_skip_notification_for_connection(connection_state, &message) + { + Some(*connection_id) + } else { + None + } + }) + .collect(); + + for connection_id in target_connections { + let Some(connection_state) = connections.get(&connection_id) else { + continue; + }; + if connection_state.writer.send(message.clone()).await.is_err() { + connections.remove(&connection_id); + disconnected.push(connection_id); + } + } + } + } + disconnected +} + +pub(crate) fn has_initialized_connections( + connections: &HashMap, +) -> bool { + connections + .values() + .any(|connection| connection.session.initialized) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error_code::OVERLOADED_ERROR_CODE; + use pretty_assertions::assert_eq; + use serde_json::json; + + #[test] + fn app_server_transport_parses_stdio_listen_url() { + let transport = AppServerTransport::from_listen_url(AppServerTransport::DEFAULT_LISTEN_URL) + .expect("stdio listen URL should parse"); + assert_eq!(transport, AppServerTransport::Stdio); + } + + #[test] + fn app_server_transport_parses_websocket_listen_url() { + let transport = AppServerTransport::from_listen_url("ws://127.0.0.1:1234") + .expect("websocket listen URL should parse"); + assert_eq!( + transport, + AppServerTransport::WebSocket { + bind_address: "127.0.0.1:1234".parse().expect("valid socket address"), + } + ); + } + + #[test] + fn app_server_transport_rejects_invalid_websocket_listen_url() { + let err = AppServerTransport::from_listen_url("ws://localhost:1234") + .expect_err("hostname bind address should be rejected"); + assert_eq!( + err.to_string(), + "invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`" + ); + } + + #[test] + fn app_server_transport_rejects_unsupported_listen_url() { + let err = AppServerTransport::from_listen_url("http://127.0.0.1:1234") + .expect_err("unsupported scheme should fail"); + assert_eq!( + err.to_string(), + "unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`" + ); + } + + #[tokio::test] + async fn enqueue_incoming_request_returns_overload_error_when_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + let first_message = + JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: first_message.clone(), + }) + .await + .expect("queue should accept first message"); + + let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(7), + method: "config/read".to_string(), + params: Some(json!({ "includeLayers": false })), + }); + assert!( + enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request).await + ); + + let queued_event = transport_event_rx + .recv() + .await + .expect("first event should stay queued"); + match queued_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message, + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(message, first_message); + } + _ => panic!("expected queued incoming message"), + } + + let overload = writer_rx + .recv() + .await + .expect("request should receive overload error"); + let overload_json = serde_json::to_value(overload).expect("serialize overload error"); + assert_eq!( + overload_json, + json!({ + "id": 7, + "error": { + "code": OVERLOADED_ERROR_CODE, + "message": "Server overloaded; retry later." + } + }) + ); + } + + #[tokio::test] + async fn enqueue_incoming_response_waits_instead_of_dropping_when_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); + let (writer_tx, _writer_rx) = mpsc::channel(1); + + let first_message = + JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: first_message.clone(), + }) + .await + .expect("queue should accept first message"); + + let response = JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { + id: codex_app_server_protocol::RequestId::Integer(7), + result: json!({"ok": true}), + }); + let transport_event_tx_for_enqueue = transport_event_tx.clone(); + let writer_tx_for_enqueue = writer_tx.clone(); + let enqueue_handle = tokio::spawn(async move { + enqueue_incoming_message( + &transport_event_tx_for_enqueue, + &writer_tx_for_enqueue, + connection_id, + response, + ) + .await + }); + + let queued_event = transport_event_rx + .recv() + .await + .expect("first event should be dequeued"); + match queued_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message, + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(message, first_message); + } + _ => panic!("expected queued incoming message"), + } + + let enqueue_result = enqueue_handle.await.expect("enqueue task should not panic"); + assert!(enqueue_result); + + let forwarded_event = transport_event_rx + .recv() + .await + .expect("response should be forwarded instead of dropped"); + match forwarded_event { + TransportEvent::IncomingMessage { + connection_id: queued_connection_id, + message: + JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { id, result }), + } => { + assert_eq!(queued_connection_id, connection_id); + assert_eq!(id, codex_app_server_protocol::RequestId::Integer(7)); + assert_eq!(result, json!({"ok": true})); + } + _ => panic!("expected forwarded response message"), + } + } + + #[tokio::test] + async fn enqueue_incoming_request_does_not_block_when_writer_queue_is_full() { + let connection_id = ConnectionId(42); + let (transport_event_tx, _transport_event_rx) = mpsc::channel(1); + let (writer_tx, mut writer_rx) = mpsc::channel(1); + + transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message: JSONRPCMessage::Notification( + codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }, + ), + }) + .await + .expect("transport queue should accept first message"); + + writer_tx + .send(OutgoingMessage::Notification( + crate::outgoing_message::OutgoingNotification { + method: "queued".to_string(), + params: None, + }, + )) + .await + .expect("writer queue should accept first message"); + + let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(7), + method: "config/read".to_string(), + params: Some(json!({ "includeLayers": false })), + }); + + let enqueue_result = tokio::time::timeout( + std::time::Duration::from_millis(100), + enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request), + ) + .await + .expect("enqueue should not block while writer queue is full"); + assert!(enqueue_result); + + let queued_outgoing = writer_rx + .recv() + .await + .expect("writer queue should still contain original message"); + let queued_json = serde_json::to_value(queued_outgoing).expect("serialize queued message"); + assert_eq!(queued_json, json!({ "method": "queued" })); + } +} diff --git a/codex-rs/app-server/tests/common/mcp_process.rs b/codex-rs/app-server/tests/common/mcp_process.rs index 7f77d8fc92a..d7ebc19e954 100644 --- a/codex-rs/app-server/tests/common/mcp_process.rs +++ b/codex-rs/app-server/tests/common/mcp_process.rs @@ -174,7 +174,7 @@ impl McpProcess { client_info, Some(InitializeCapabilities { experimental_api: true, - opt_out_notification_methods: None, + ..Default::default() }), ) .await diff --git a/codex-rs/app-server/tests/suite/v2/analytics.rs b/codex-rs/app-server/tests/suite/v2/analytics.rs index e18a0d3c849..0d05d644658 100644 --- a/codex-rs/app-server/tests/suite/v2/analytics.rs +++ b/codex-rs/app-server/tests/suite/v2/analytics.rs @@ -36,8 +36,9 @@ async fn app_server_default_analytics_disabled_without_flag() -> Result<()> { .map_err(|err| anyhow::anyhow!(err.to_string()))?; // With analytics unset in the config and the default flag is false, metrics are disabled. - // No provider is built. - assert_eq!(provider.is_none(), true); + // A provider may still exist for non-metrics telemetry, so check metrics specifically. + let has_metrics = provider.as_ref().and_then(|otel| otel.metrics()).is_some(); + assert_eq!(has_metrics, false); Ok(()) } diff --git a/codex-rs/app-server/tests/suite/v2/config_rpc.rs b/codex-rs/app-server/tests/suite/v2/config_rpc.rs index 4129564b16b..de4c51cde47 100644 --- a/codex-rs/app-server/tests/suite/v2/config_rpc.rs +++ b/codex-rs/app-server/tests/suite/v2/config_rpc.rs @@ -560,9 +560,22 @@ fn assert_layers_user_then_optional_system( layers: &[codex_app_server_protocol::ConfigLayer], user_file: AbsolutePathBuf, ) -> Result<()> { - assert_eq!(layers.len(), 2); - assert_eq!(layers[0].name, ConfigLayerSource::User { file: user_file }); - assert!(matches!(layers[1].name, ConfigLayerSource::System { .. })); + let mut first_index = 0; + if matches!( + layers.first().map(|layer| &layer.name), + Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm) + ) { + first_index = 1; + } + assert_eq!(layers.len(), first_index + 2); + assert_eq!( + layers[first_index].name, + ConfigLayerSource::User { file: user_file } + ); + assert!(matches!( + layers[first_index + 1].name, + ConfigLayerSource::System { .. } + )); Ok(()) } @@ -571,12 +584,25 @@ fn assert_layers_managed_user_then_optional_system( managed_file: AbsolutePathBuf, user_file: AbsolutePathBuf, ) -> Result<()> { - assert_eq!(layers.len(), 3); + let mut first_index = 0; + if matches!( + layers.first().map(|layer| &layer.name), + Some(ConfigLayerSource::LegacyManagedConfigTomlFromMdm) + ) { + first_index = 1; + } + assert_eq!(layers.len(), first_index + 3); assert_eq!( - layers[0].name, + layers[first_index].name, ConfigLayerSource::LegacyManagedConfigTomlFromFile { file: managed_file } ); - assert_eq!(layers[1].name, ConfigLayerSource::User { file: user_file }); - assert!(matches!(layers[2].name, ConfigLayerSource::System { .. })); + assert_eq!( + layers[first_index + 1].name, + ConfigLayerSource::User { file: user_file } + ); + assert!(matches!( + layers[first_index + 2].name, + ConfigLayerSource::System { .. } + )); Ok(()) } diff --git a/codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs b/codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs new file mode 100644 index 00000000000..ddd4326fc99 --- /dev/null +++ b/codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs @@ -0,0 +1,263 @@ +use anyhow::Context; +use anyhow::Result; +use anyhow::bail; +use app_test_support::create_mock_responses_server_sequence_unchecked; +use codex_app_server_protocol::ClientInfo; +use codex_app_server_protocol::InitializeParams; +use codex_app_server_protocol::JSONRPCError; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCRequest; +use codex_app_server_protocol::JSONRPCResponse; +use codex_app_server_protocol::RequestId; +use futures::SinkExt; +use futures::StreamExt; +use serde_json::json; +use std::net::SocketAddr; +use std::path::Path; +use std::process::Stdio; +use tempfile::TempDir; +use tokio::io::AsyncBufReadExt; +use tokio::process::Child; +use tokio::process::Command; +use tokio::time::Duration; +use tokio::time::Instant; +use tokio::time::sleep; +use tokio::time::timeout; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::Message as WebSocketMessage; + +const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5); + +type WsClient = WebSocketStream>; + +#[tokio::test] +async fn websocket_transport_routes_per_connection_handshake_and_responses() -> Result<()> { + let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri(), "never")?; + + let bind_addr = reserve_local_addr()?; + let mut process = spawn_websocket_server(codex_home.path(), bind_addr).await?; + + let mut ws1 = connect_websocket(bind_addr).await?; + let mut ws2 = connect_websocket(bind_addr).await?; + + send_initialize_request(&mut ws1, 1, "ws_client_one").await?; + let first_init = read_response_for_id(&mut ws1, 1).await?; + assert_eq!(first_init.id, RequestId::Integer(1)); + + // Initialize responses are request-scoped and must not leak to other + // connections. + assert_no_message(&mut ws2, Duration::from_millis(250)).await?; + + send_config_read_request(&mut ws2, 2).await?; + let not_initialized = read_error_for_id(&mut ws2, 2).await?; + assert_eq!(not_initialized.error.message, "Not initialized"); + + send_initialize_request(&mut ws2, 3, "ws_client_two").await?; + let second_init = read_response_for_id(&mut ws2, 3).await?; + assert_eq!(second_init.id, RequestId::Integer(3)); + + // Same request-id on different connections must route independently. + send_config_read_request(&mut ws1, 77).await?; + send_config_read_request(&mut ws2, 77).await?; + let ws1_config = read_response_for_id(&mut ws1, 77).await?; + let ws2_config = read_response_for_id(&mut ws2, 77).await?; + + assert_eq!(ws1_config.id, RequestId::Integer(77)); + assert_eq!(ws2_config.id, RequestId::Integer(77)); + assert!(ws1_config.result.get("config").is_some()); + assert!(ws2_config.result.get("config").is_some()); + + process + .kill() + .await + .context("failed to stop websocket app-server process")?; + Ok(()) +} + +async fn spawn_websocket_server(codex_home: &Path, bind_addr: SocketAddr) -> Result { + let program = codex_utils_cargo_bin::cargo_bin("codex-app-server") + .context("should find app-server binary")?; + let mut cmd = Command::new(program); + cmd.arg("--listen") + .arg(format!("ws://{bind_addr}")) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::piped()) + .env("CODEX_HOME", codex_home) + .env("RUST_LOG", "debug"); + let mut process = cmd + .kill_on_drop(true) + .spawn() + .context("failed to spawn websocket app-server process")?; + + if let Some(stderr) = process.stderr.take() { + let mut stderr_reader = tokio::io::BufReader::new(stderr).lines(); + tokio::spawn(async move { + while let Ok(Some(line)) = stderr_reader.next_line().await { + eprintln!("[websocket app-server stderr] {line}"); + } + }); + } + + Ok(process) +} + +fn reserve_local_addr() -> Result { + let listener = std::net::TcpListener::bind("127.0.0.1:0")?; + let addr = listener.local_addr()?; + drop(listener); + Ok(addr) +} + +async fn connect_websocket(bind_addr: SocketAddr) -> Result { + let url = format!("ws://{bind_addr}"); + let deadline = Instant::now() + Duration::from_secs(10); + loop { + match connect_async(&url).await { + Ok((stream, _response)) => return Ok(stream), + Err(err) => { + if Instant::now() >= deadline { + bail!("failed to connect websocket to {url}: {err}"); + } + sleep(Duration::from_millis(50)).await; + } + } + } +} + +async fn send_initialize_request(stream: &mut WsClient, id: i64, client_name: &str) -> Result<()> { + let params = InitializeParams { + client_info: ClientInfo { + name: client_name.to_string(), + title: Some("WebSocket Test Client".to_string()), + version: "0.1.0".to_string(), + }, + capabilities: None, + }; + send_request( + stream, + "initialize", + id, + Some(serde_json::to_value(params)?), + ) + .await +} + +async fn send_config_read_request(stream: &mut WsClient, id: i64) -> Result<()> { + send_request( + stream, + "config/read", + id, + Some(json!({ "includeLayers": false })), + ) + .await +} + +async fn send_request( + stream: &mut WsClient, + method: &str, + id: i64, + params: Option, +) -> Result<()> { + let message = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(id), + method: method.to_string(), + params, + }); + send_jsonrpc(stream, message).await +} + +async fn send_jsonrpc(stream: &mut WsClient, message: JSONRPCMessage) -> Result<()> { + let payload = serde_json::to_string(&message)?; + stream + .send(WebSocketMessage::Text(payload.into())) + .await + .context("failed to send websocket frame") +} + +async fn read_response_for_id(stream: &mut WsClient, id: i64) -> Result { + let target_id = RequestId::Integer(id); + loop { + let message = read_jsonrpc_message(stream).await?; + if let JSONRPCMessage::Response(response) = message + && response.id == target_id + { + return Ok(response); + } + } +} + +async fn read_error_for_id(stream: &mut WsClient, id: i64) -> Result { + let target_id = RequestId::Integer(id); + loop { + let message = read_jsonrpc_message(stream).await?; + if let JSONRPCMessage::Error(err) = message + && err.id == target_id + { + return Ok(err); + } + } +} + +async fn read_jsonrpc_message(stream: &mut WsClient) -> Result { + loop { + let frame = timeout(DEFAULT_READ_TIMEOUT, stream.next()) + .await + .context("timed out waiting for websocket frame")? + .context("websocket stream ended unexpectedly")? + .context("failed to read websocket frame")?; + + match frame { + WebSocketMessage::Text(text) => return Ok(serde_json::from_str(text.as_ref())?), + WebSocketMessage::Ping(payload) => { + stream.send(WebSocketMessage::Pong(payload)).await?; + } + WebSocketMessage::Pong(_) => {} + WebSocketMessage::Close(frame) => { + bail!("websocket closed unexpectedly: {frame:?}") + } + WebSocketMessage::Binary(_) => bail!("unexpected binary websocket frame"), + WebSocketMessage::Frame(_) => {} + } + } +} + +async fn assert_no_message(stream: &mut WsClient, wait_for: Duration) -> Result<()> { + match timeout(wait_for, stream.next()).await { + Ok(Some(Ok(frame))) => bail!("unexpected frame while waiting for silence: {frame:?}"), + Ok(Some(Err(err))) => bail!("unexpected websocket read error: {err}"), + Ok(None) => bail!("websocket closed unexpectedly while waiting for silence"), + Err(_) => Ok(()), + } +} + +fn create_config_toml( + codex_home: &Path, + server_uri: &str, + approval_policy: &str, +) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "{approval_policy}" +sandbox_mode = "read-only" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "responses" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +} diff --git a/codex-rs/app-server/tests/suite/v2/mod.rs b/codex-rs/app-server/tests/suite/v2/mod.rs index e9e19395be4..48622acddbe 100644 --- a/codex-rs/app-server/tests/suite/v2/mod.rs +++ b/codex-rs/app-server/tests/suite/v2/mod.rs @@ -4,6 +4,7 @@ mod app_list; mod collaboration_mode_list; mod compaction; mod config_rpc; +mod connection_handling_websocket; mod dynamic_tools; mod experimental_api; mod experimental_feature_list; diff --git a/codex-rs/app-server/tests/suite/v2/review.rs b/codex-rs/app-server/tests/suite/v2/review.rs index 441ad2ce19f..7814950a631 100644 --- a/codex-rs/app-server/tests/suite/v2/review.rs +++ b/codex-rs/app-server/tests/suite/v2/review.rs @@ -5,8 +5,6 @@ use app_test_support::create_mock_responses_server_repeating_assistant; use app_test_support::create_mock_responses_server_sequence; use app_test_support::create_shell_command_sse_response; use app_test_support::to_response; -use codex_app_server_protocol::CommandExecutionApprovalDecision; -use codex_app_server_protocol::CommandExecutionRequestApprovalResponse; use codex_app_server_protocol::ItemCompletedNotification; use codex_app_server_protocol::ItemStartedNotification; use codex_app_server_protocol::JSONRPCError; @@ -211,9 +209,7 @@ async fn review_start_exec_approval_item_id_matches_command_execution_item() -> mcp.send_response( request_id, - serde_json::to_value(CommandExecutionRequestApprovalResponse { - decision: CommandExecutionApprovalDecision::Accept, - })?, + serde_json::json!({ "decision": codex_core::protocol::ReviewDecision::Approved }), ) .await?; timeout( diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index fa4dc5e3d92..494f8d89351 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -306,6 +306,15 @@ struct AppServerCommand { #[command(subcommand)] subcommand: Option, + /// Transport endpoint URL. Supported values: `stdio://` (default), + /// `ws://IP:PORT`. + #[arg( + long = "listen", + value_name = "URL", + default_value = codex_app_server::AppServerTransport::DEFAULT_LISTEN_URL + )] + listen: codex_app_server::AppServerTransport, + /// Controls whether analytics are enabled by default. /// /// Analytics are disabled by default for app-server. Users have to explicitly opt in @@ -587,11 +596,13 @@ async fn cli_main(codex_linux_sandbox_exe: Option) -> anyhow::Result<() } Some(Subcommand::AppServer(app_server_cli)) => match app_server_cli.subcommand { None => { - codex_app_server::run_main( + let transport = app_server_cli.listen; + codex_app_server::run_main_with_transport( codex_linux_sandbox_exe, root_config_overrides, codex_core::config_loader::LoaderOverrides::default(), app_server_cli.analytics_default_enabled, + transport, ) .await?; } @@ -1328,6 +1339,10 @@ mod tests { fn app_server_analytics_default_disabled_without_flag() { let app_server = app_server_from_args(["codex", "app-server"].as_ref()); assert!(!app_server.analytics_default_enabled); + assert_eq!( + app_server.listen, + codex_app_server::AppServerTransport::Stdio + ); } #[test] @@ -1337,6 +1352,36 @@ mod tests { assert!(app_server.analytics_default_enabled); } + #[test] + fn app_server_listen_websocket_url_parses() { + let app_server = app_server_from_args( + ["codex", "app-server", "--listen", "ws://127.0.0.1:4500"].as_ref(), + ); + assert_eq!( + app_server.listen, + codex_app_server::AppServerTransport::WebSocket { + bind_address: "127.0.0.1:4500".parse().expect("valid socket address"), + } + ); + } + + #[test] + fn app_server_listen_stdio_url_parses() { + let app_server = + app_server_from_args(["codex", "app-server", "--listen", "stdio://"].as_ref()); + assert_eq!( + app_server.listen, + codex_app_server::AppServerTransport::Stdio + ); + } + + #[test] + fn app_server_listen_invalid_url_fails_to_parse() { + let parse_result = + MultitoolCli::try_parse_from(["codex", "app-server", "--listen", "http://foo"]); + assert!(parse_result.is_err()); + } + #[test] fn features_enable_parses_feature_name() { let cli = MultitoolCli::try_parse_from(["codex", "features", "enable", "unified_exec"]) diff --git a/codex-rs/core/tests/suite/review.rs b/codex-rs/core/tests/suite/review.rs index a7010ecaf1d..1c9c3adf7b0 100644 --- a/codex-rs/core/tests/suite/review.rs +++ b/codex-rs/core/tests/suite/review.rs @@ -371,25 +371,6 @@ async fn review_does_not_emit_agent_message_on_structured_output() { _ => false, }) .await; - // On slower CI hosts, the final AgentMessage can arrive immediately after - // TurnComplete. Drain a brief tail window to make ordering nondeterminism - // harmless while still enforcing "exactly one final AgentMessage". - while let Ok(Ok(event)) = - tokio::time::timeout(std::time::Duration::from_millis(200), codex.next_event()).await - { - match event.msg { - EventMsg::AgentMessage(_) => agent_messages += 1, - EventMsg::EnteredReviewMode(_) => saw_entered = true, - EventMsg::ExitedReviewMode(_) => saw_exited = true, - EventMsg::AgentMessageContentDelta(_) => { - panic!("unexpected AgentMessageContentDelta surfaced during review") - } - EventMsg::AgentMessageDelta(_) => { - panic!("unexpected AgentMessageDelta surfaced during review") - } - _ => {} - } - } assert_eq!(1, agent_messages, "expected exactly one AgentMessage event"); assert!(saw_entered && saw_exited, "missing review lifecycle events");