diff --git a/codex-rs/app-server/tests/common/mcp_process.rs b/codex-rs/app-server/tests/common/mcp_process.rs index f3ec682fb21..af130e84603 100644 --- a/codex-rs/app-server/tests/common/mcp_process.rs +++ b/codex-rs/app-server/tests/common/mcp_process.rs @@ -60,7 +60,7 @@ pub struct McpProcess { process: Child, stdin: ChildStdin, stdout: BufReader, - pending_user_messages: VecDeque, + pending_messages: VecDeque, } impl McpProcess { @@ -127,7 +127,7 @@ impl McpProcess { process, stdin, stdout, - pending_user_messages: VecDeque::new(), + pending_messages: VecDeque::new(), }) } @@ -544,27 +544,16 @@ impl McpProcess { pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result { eprintln!("in read_stream_until_request_message()"); - loop { - let message = self.read_jsonrpc_message().await?; + let message = self + .read_stream_until_message(|message| matches!(message, JSONRPCMessage::Request(_))) + .await?; - match message { - JSONRPCMessage::Notification(notification) => { - eprintln!("notification: {notification:?}"); - self.enqueue_user_message(notification); - } - JSONRPCMessage::Request(jsonrpc_request) => { - return jsonrpc_request.try_into().with_context( - || "failed to deserialize ServerRequest from JSONRPCRequest", - ); - } - JSONRPCMessage::Error(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}"); - } - JSONRPCMessage::Response(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}"); - } - } - } + let JSONRPCMessage::Request(jsonrpc_request) = message else { + unreachable!("expected JSONRPCMessage::Request, got {message:?}"); + }; + jsonrpc_request + .try_into() + .with_context(|| "failed to deserialize ServerRequest from JSONRPCRequest") } pub async fn read_stream_until_response_message( @@ -573,52 +562,32 @@ impl McpProcess { ) -> anyhow::Result { eprintln!("in read_stream_until_response_message({request_id:?})"); - loop { - let message = self.read_jsonrpc_message().await?; - match message { - JSONRPCMessage::Notification(notification) => { - eprintln!("notification: {notification:?}"); - self.enqueue_user_message(notification); - } - JSONRPCMessage::Request(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}"); - } - JSONRPCMessage::Error(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}"); - } - JSONRPCMessage::Response(jsonrpc_response) => { - if jsonrpc_response.id == request_id { - return Ok(jsonrpc_response); - } - } - } - } + let message = self + .read_stream_until_message(|message| { + Self::message_request_id(message) == Some(&request_id) + }) + .await?; + + let JSONRPCMessage::Response(response) = message else { + unreachable!("expected JSONRPCMessage::Response, got {message:?}"); + }; + Ok(response) } pub async fn read_stream_until_error_message( &mut self, request_id: RequestId, ) -> anyhow::Result { - loop { - let message = self.read_jsonrpc_message().await?; - match message { - JSONRPCMessage::Notification(notification) => { - eprintln!("notification: {notification:?}"); - self.enqueue_user_message(notification); - } - JSONRPCMessage::Request(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}"); - } - JSONRPCMessage::Response(_) => { - // Keep scanning; we're waiting for an error with matching id. - } - JSONRPCMessage::Error(err) => { - if err.id == request_id { - return Ok(err); - } - } - } - } + let message = self + .read_stream_until_message(|message| { + Self::message_request_id(message) == Some(&request_id) + }) + .await?; + + let JSONRPCMessage::Error(err) = message else { + unreachable!("expected JSONRPCMessage::Error, got {message:?}"); + }; + Ok(err) } pub async fn read_stream_until_notification_message( @@ -627,46 +596,64 @@ impl McpProcess { ) -> anyhow::Result { eprintln!("in read_stream_until_notification_message({method})"); - if let Some(notification) = self.take_pending_notification_by_method(method) { - return Ok(notification); + let message = self + .read_stream_until_message(|message| { + matches!( + message, + JSONRPCMessage::Notification(notification) if notification.method == method + ) + }) + .await?; + + let JSONRPCMessage::Notification(notification) = message else { + unreachable!("expected JSONRPCMessage::Notification, got {message:?}"); + }; + Ok(notification) + } + + /// Clears any buffered messages so future reads only consider new stream items. + /// + /// We call this when e.g. we want to validate against the next turn and no longer care about + /// messages buffered from the prior turn. + pub fn clear_message_buffer(&mut self) { + self.pending_messages.clear(); + } + + /// Reads the stream until a message matches `predicate`, buffering any non-matching messages + /// for later reads. + async fn read_stream_until_message(&mut self, predicate: F) -> anyhow::Result + where + F: Fn(&JSONRPCMessage) -> bool, + { + if let Some(message) = self.take_pending_message(&predicate) { + return Ok(message); } loop { let message = self.read_jsonrpc_message().await?; - match message { - JSONRPCMessage::Notification(notification) => { - if notification.method == method { - return Ok(notification); - } - self.enqueue_user_message(notification); - } - JSONRPCMessage::Request(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}"); - } - JSONRPCMessage::Error(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}"); - } - JSONRPCMessage::Response(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}"); - } + if predicate(&message) { + return Ok(message); } + self.pending_messages.push_back(message); } } - fn take_pending_notification_by_method(&mut self, method: &str) -> Option { - if let Some(pos) = self - .pending_user_messages - .iter() - .position(|notification| notification.method == method) - { - return self.pending_user_messages.remove(pos); + fn take_pending_message(&mut self, predicate: &F) -> Option + where + F: Fn(&JSONRPCMessage) -> bool, + { + if let Some(pos) = self.pending_messages.iter().position(predicate) { + return self.pending_messages.remove(pos); } None } - fn enqueue_user_message(&mut self, notification: JSONRPCNotification) { - if notification.method == "codex/event/user_message" { - self.pending_user_messages.push_back(notification); + fn message_request_id(message: &JSONRPCMessage) -> Option<&RequestId> { + match message { + JSONRPCMessage::Request(request) => Some(&request.id), + JSONRPCMessage::Response(response) => Some(&response.id), + JSONRPCMessage::Error(err) => Some(&err.id), + JSONRPCMessage::Notification(_) => None, } } } diff --git a/codex-rs/app-server/tests/suite/codex_message_processor_flow.rs b/codex-rs/app-server/tests/suite/codex_message_processor_flow.rs index a508bf88057..1dcb917f085 100644 --- a/codex-rs/app-server/tests/suite/codex_message_processor_flow.rs +++ b/codex-rs/app-server/tests/suite/codex_message_processor_flow.rs @@ -430,6 +430,7 @@ async fn test_send_user_turn_updates_sandbox_and_cwd_between_turns() -> Result<( mcp.read_stream_until_notification_message("codex/event/task_complete"), ) .await??; + mcp.clear_message_buffer(); let second_turn_id = mcp .send_send_user_turn_request(SendUserTurnParams { diff --git a/codex-rs/app-server/tests/suite/send_message.rs b/codex-rs/app-server/tests/suite/send_message.rs index ed93f8a7f3e..f57b5f2ee4a 100644 --- a/codex-rs/app-server/tests/suite/send_message.rs +++ b/codex-rs/app-server/tests/suite/send_message.rs @@ -1,7 +1,5 @@ use anyhow::Result; use app_test_support::McpProcess; -use app_test_support::create_final_assistant_message_sse_response; -use app_test_support::create_mock_chat_completions_server; use app_test_support::to_response; use codex_app_server_protocol::AddConversationListenerParams; use codex_app_server_protocol::AddConversationSubscriptionResponse; @@ -17,6 +15,7 @@ use codex_protocol::ThreadId; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::RawResponseItemEvent; +use core_test_support::responses; use pretty_assertions::assert_eq; use std::path::Path; use tempfile::TempDir; @@ -26,13 +25,21 @@ const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs #[tokio::test] async fn test_send_message_success() -> Result<()> { - // Spin up a mock completions server that immediately ends the Codex turn. + // Spin up a mock responses server that immediately ends the Codex turn. // Two Codex turns hit the mock model (session start + send-user-message). Provide two SSE responses. - let responses = vec![ - create_final_assistant_message_sse_response("Done")?, - create_final_assistant_message_sse_response("Done")?, - ]; - let server = create_mock_chat_completions_server(responses).await; + let server = responses::start_mock_server().await; + let body1 = responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_assistant_message("msg-1", "Done"), + responses::ev_completed("resp-1"), + ]); + let body2 = responses::sse(vec![ + responses::ev_response_created("resp-2"), + responses::ev_assistant_message("msg-2", "Done"), + responses::ev_completed("resp-2"), + ]); + let _response_mock1 = responses::mount_sse_once(&server, body1).await; + let _response_mock2 = responses::mount_sse_once(&server, body2).await; // Create a temporary Codex home with config pointing at the mock server. let codex_home = TempDir::new()?; @@ -135,8 +142,13 @@ async fn send_message( #[tokio::test] async fn test_send_message_raw_notifications_opt_in() -> Result<()> { - let responses = vec![create_final_assistant_message_sse_response("Done")?]; - let server = create_mock_chat_completions_server(responses).await; + let server = responses::start_mock_server().await; + let body = responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_assistant_message("msg-1", "Done"), + responses::ev_completed("resp-1"), + ]); + let _response_mock = responses::mount_sse_once(&server, body).await; let codex_home = TempDir::new()?; create_config_toml(codex_home.path(), &server.uri())?; @@ -259,7 +271,7 @@ model_provider = "mock_provider" [model_providers.mock_provider] name = "Mock provider for test" base_url = "{server_uri}/v1" -wire_api = "chat" +wire_api = "responses" request_max_retries = 0 stream_max_retries = 0 "# @@ -269,6 +281,7 @@ stream_max_retries = 0 #[expect(clippy::expect_used)] async fn read_raw_response_item(mcp: &mut McpProcess, conversation_id: ThreadId) -> ResponseItem { + // TODO: Switch to rawResponseItem/completed once we migrate to app server v2 in codex web. loop { let raw_notification: JSONRPCNotification = timeout( DEFAULT_READ_TIMEOUT, diff --git a/codex-rs/app-server/tests/suite/v2/turn_start.rs b/codex-rs/app-server/tests/suite/v2/turn_start.rs index ab450ea832c..d992eeee4b2 100644 --- a/codex-rs/app-server/tests/suite/v2/turn_start.rs +++ b/codex-rs/app-server/tests/suite/v2/turn_start.rs @@ -554,6 +554,7 @@ async fn turn_start_updates_sandbox_and_cwd_between_turns_v2() -> Result<()> { mcp.read_stream_until_notification_message("codex/event/task_complete"), ) .await??; + mcp.clear_message_buffer(); // second turn with workspace-write and second_cwd, ensure exec begins in second_cwd let second_turn = mcp