Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 123 additions & 10 deletions codex-rs/codex-api/src/endpoint/responses_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::Instant;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
Expand All @@ -41,7 +42,124 @@ use tungstenite::extensions::compression::deflate::DeflateConfig;
use tungstenite::protocol::WebSocketConfig;
use url::Url;

type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
struct WsStream {
tx_command: mpsc::Sender<WsCommand>,
rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
pump_task: tokio::task::JoinHandle<()>,
}

enum WsCommand {
Send {
message: Message,
tx_result: oneshot::Sender<Result<(), WsError>>,
},
Close {
tx_result: oneshot::Sender<Result<(), WsError>>,
},
}

impl WsStream {
fn new(inner: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
let (tx_command, mut rx_command) = mpsc::channel::<WsCommand>(32);
let (tx_message, rx_message) = mpsc::unbounded_channel::<Result<Message, WsError>>();

let pump_task = tokio::spawn(async move {
let mut inner = inner;
loop {
tokio::select! {
command = rx_command.recv() => {
let Some(command) = command else {
break;
};
match command {
WsCommand::Send { message, tx_result } => {
let result = inner.send(message).await;
let should_break = result.is_err();
let _ = tx_result.send(result);
if should_break {
break;
}
}
WsCommand::Close { tx_result } => {
let result = inner.close(None).await;
let _ = tx_result.send(result);
break;
}
}
}
message = inner.next() => {
let Some(message) = message else {
break;
};
match message {
Ok(Message::Ping(payload)) => {
if let Err(err) = inner.send(Message::Pong(payload)).await {
let _ = tx_message.send(Err(err));
break;
}
}
Ok(Message::Pong(_)) => {}
Ok(message @ (Message::Text(_)
| Message::Binary(_)
| Message::Close(_)
| Message::Frame(_))) => {
let is_close = matches!(message, Message::Close(_));
if tx_message.send(Ok(message)).is_err() {
break;
}
if is_close {
break;
}
}
Err(err) => {
let _ = tx_message.send(Err(err));
break;
}
}
}
}
}
});

Self {
tx_command,
rx_message,
pump_task,
}
}

async fn request(
&self,
make_command: impl FnOnce(oneshot::Sender<Result<(), WsError>>) -> WsCommand,
) -> Result<(), WsError> {
let (tx_result, rx_result) = oneshot::channel();
if self.tx_command.send(make_command(tx_result)).await.is_err() {
return Err(WsError::ConnectionClosed);
}
rx_result.await.unwrap_or(Err(WsError::ConnectionClosed))
}

async fn send(&self, message: Message) -> Result<(), WsError> {
self.request(|tx_result| WsCommand::Send { message, tx_result })
.await
}

async fn close(&self) -> Result<(), WsError> {
self.request(|tx_result| WsCommand::Close { tx_result })
.await
}

async fn next(&mut self) -> Option<Result<Message, WsError>> {
self.rx_message.recv().await
}
}

impl Drop for WsStream {
fn drop(&mut self) {
self.pump_task.abort();
}
}

const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
const X_MODELS_ETAG_HEADER: &str = "x-models-etag";
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";
Expand Down Expand Up @@ -119,7 +237,7 @@ impl ResponsesWebsocketConnection {
)
.await
{
let _ = ws_stream.close(None).await;
let _ = ws_stream.close().await;
*guard = None;
let _ = tx_event.send(Err(err)).await;
}
Expand Down Expand Up @@ -215,7 +333,7 @@ async fn connect_websocket(
{
let _ = turn_state.set(header_value.to_string());
}
Ok((stream, reasoning_included, models_etag))
Ok((WsStream::new(stream), reasoning_included, models_etag))
}

fn websocket_config() -> WebSocketConfig {
Expand Down Expand Up @@ -419,18 +537,13 @@ async fn run_websocket_response_stream(
Message::Binary(_) => {
return Err(ApiError::Stream("unexpected binary websocket event".into()));
}
Message::Ping(payload) => {
if ws_stream.send(Message::Pong(payload)).await.is_err() {
return Err(ApiError::Stream("websocket ping failed".into()));
}
}
Message::Pong(_) => {}
Message::Close(_) => {
return Err(ApiError::Stream(
"websocket closed by server before response.completed".into(),
));
}
_ => {}
Message::Frame(_) => {}
Message::Ping(_) | Message::Pong(_) => {}
}
}

Expand Down
Loading