diff --git a/codex-rs/codex-api/src/endpoint/responses_websocket.rs b/codex-rs/codex-api/src/endpoint/responses_websocket.rs index b4871153df9..175507f27c4 100644 --- a/codex-rs/codex-api/src/endpoint/responses_websocket.rs +++ b/codex-rs/codex-api/src/endpoint/responses_websocket.rs @@ -26,6 +26,7 @@ use std::time::Duration; use tokio::net::TcpStream; use tokio::sync::Mutex; use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::time::Instant; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::WebSocketStream; @@ -41,7 +42,124 @@ use tungstenite::extensions::compression::deflate::DeflateConfig; use tungstenite::protocol::WebSocketConfig; use url::Url; -type WsStream = WebSocketStream>; +struct WsStream { + tx_command: mpsc::Sender, + rx_message: mpsc::UnboundedReceiver>, + pump_task: tokio::task::JoinHandle<()>, +} + +enum WsCommand { + Send { + message: Message, + tx_result: oneshot::Sender>, + }, + Close { + tx_result: oneshot::Sender>, + }, +} + +impl WsStream { + fn new(inner: WebSocketStream>) -> Self { + let (tx_command, mut rx_command) = mpsc::channel::(32); + let (tx_message, rx_message) = mpsc::unbounded_channel::>(); + + let pump_task = tokio::spawn(async move { + let mut inner = inner; + loop { + tokio::select! { + command = rx_command.recv() => { + let Some(command) = command else { + break; + }; + match command { + WsCommand::Send { message, tx_result } => { + let result = inner.send(message).await; + let should_break = result.is_err(); + let _ = tx_result.send(result); + if should_break { + break; + } + } + WsCommand::Close { tx_result } => { + let result = inner.close(None).await; + let _ = tx_result.send(result); + break; + } + } + } + message = inner.next() => { + let Some(message) = message else { + break; + }; + match message { + Ok(Message::Ping(payload)) => { + if let Err(err) = inner.send(Message::Pong(payload)).await { + let _ = tx_message.send(Err(err)); + break; + } + } + Ok(Message::Pong(_)) => {} + Ok(message @ (Message::Text(_) + | Message::Binary(_) + | Message::Close(_) + | Message::Frame(_))) => { + let is_close = matches!(message, Message::Close(_)); + if tx_message.send(Ok(message)).is_err() { + break; + } + if is_close { + break; + } + } + Err(err) => { + let _ = tx_message.send(Err(err)); + break; + } + } + } + } + } + }); + + Self { + tx_command, + rx_message, + pump_task, + } + } + + async fn request( + &self, + make_command: impl FnOnce(oneshot::Sender>) -> WsCommand, + ) -> Result<(), WsError> { + let (tx_result, rx_result) = oneshot::channel(); + if self.tx_command.send(make_command(tx_result)).await.is_err() { + return Err(WsError::ConnectionClosed); + } + rx_result.await.unwrap_or(Err(WsError::ConnectionClosed)) + } + + async fn send(&self, message: Message) -> Result<(), WsError> { + self.request(|tx_result| WsCommand::Send { message, tx_result }) + .await + } + + async fn close(&self) -> Result<(), WsError> { + self.request(|tx_result| WsCommand::Close { tx_result }) + .await + } + + async fn next(&mut self) -> Option> { + self.rx_message.recv().await + } +} + +impl Drop for WsStream { + fn drop(&mut self) { + self.pump_task.abort(); + } +} + const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state"; const X_MODELS_ETAG_HEADER: &str = "x-models-etag"; const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included"; @@ -119,7 +237,7 @@ impl ResponsesWebsocketConnection { ) .await { - let _ = ws_stream.close(None).await; + let _ = ws_stream.close().await; *guard = None; let _ = tx_event.send(Err(err)).await; } @@ -215,7 +333,7 @@ async fn connect_websocket( { let _ = turn_state.set(header_value.to_string()); } - Ok((stream, reasoning_included, models_etag)) + Ok((WsStream::new(stream), reasoning_included, models_etag)) } fn websocket_config() -> WebSocketConfig { @@ -419,18 +537,13 @@ async fn run_websocket_response_stream( Message::Binary(_) => { return Err(ApiError::Stream("unexpected binary websocket event".into())); } - Message::Ping(payload) => { - if ws_stream.send(Message::Pong(payload)).await.is_err() { - return Err(ApiError::Stream("websocket ping failed".into())); - } - } - Message::Pong(_) => {} Message::Close(_) => { return Err(ApiError::Stream( "websocket closed by server before response.completed".into(), )); } - _ => {} + Message::Frame(_) => {} + Message::Ping(_) | Message::Pong(_) => {} } }