Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codex-rs/codex-api/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ pub enum ResponseEvent {
Completed {
response_id: String,
token_usage: Option<TokenUsage>,
/// Whether the client can append more items to a long-running websocket response.
can_append: bool,
},
OutputTextDelta(String),
ReasoningSummaryDelta {
Expand Down
3 changes: 3 additions & 0 deletions codex-rs/codex-api/src/endpoint/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ impl Stream for AggregatedStream {
Poll::Ready(Some(Ok(ResponseEvent::Completed {
response_id,
token_usage,
can_append: _can_append,
}))) => {
let mut emitted_any = false;

Expand Down Expand Up @@ -102,6 +103,7 @@ impl Stream for AggregatedStream {
this.pending.push_back(ResponseEvent::Completed {
response_id: response_id.clone(),
token_usage: token_usage.clone(),
can_append: false,
});
if let Some(ev) = this.pending.pop_front() {
return Poll::Ready(Some(Ok(ev)));
Expand All @@ -111,6 +113,7 @@ impl Stream for AggregatedStream {
return Poll::Ready(Some(Ok(ResponseEvent::Completed {
response_id,
token_usage,
can_append: false,
})));
}
Poll::Ready(Some(Ok(ResponseEvent::Created))) => continue,
Expand Down
1 change: 1 addition & 0 deletions codex-rs/codex-api/src/endpoint/responses_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ async fn run_websocket_response_stream(
)));
}
};
trace!("websocket request: {request_text}");

let request_start = Instant::now();
let result = ws_stream
Expand Down
13 changes: 12 additions & 1 deletion codex-rs/codex-api/src/sse/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ pub fn process_responses_event(
return Ok(Some(ResponseEvent::Completed {
response_id: resp.id,
token_usage: resp.usage.map(Into::into),
can_append: false,
}));
}
Err(err) => {
Expand All @@ -276,6 +277,7 @@ pub fn process_responses_event(
return Ok(Some(ResponseEvent::Completed {
response_id: resp.id.unwrap_or_default(),
token_usage: resp.usage.map(Into::into),
can_append: true,
}));
}
Err(err) => {
Expand All @@ -290,6 +292,7 @@ pub fn process_responses_event(
return Ok(Some(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
can_append: true,
}));
}
"response.output_item.added" => {
Expand Down Expand Up @@ -548,9 +551,11 @@ mod tests {
Ok(ResponseEvent::Completed {
response_id,
token_usage,
can_append,
}) => {
assert_eq!(response_id, "resp1");
assert!(token_usage.is_none());
assert!(!can_append);
}
other => panic!("unexpected third event: {other:?}"),
}
Expand Down Expand Up @@ -585,7 +590,7 @@ mod tests {
}

#[tokio::test]
async fn response_done_emits_completed() {
async fn response_done_emits_incremental_completed() {
let done = json!({
"type": "response.done",
"response": {
Expand All @@ -610,9 +615,11 @@ mod tests {
Ok(ResponseEvent::Completed {
response_id,
token_usage,
can_append,
}) => {
assert_eq!(response_id, "");
assert!(token_usage.is_some());
assert!(*can_append);
}
other => panic!("unexpected event: {other:?}"),
}
Expand All @@ -635,9 +642,11 @@ mod tests {
Ok(ResponseEvent::Completed {
response_id,
token_usage,
can_append,
}) => {
assert_eq!(response_id, "");
assert!(token_usage.is_none());
assert!(*can_append);
}
other => panic!("unexpected event: {other:?}"),
}
Expand Down Expand Up @@ -673,9 +682,11 @@ mod tests {
Ok(ResponseEvent::Completed {
response_id,
token_usage,
can_append,
}) => {
assert_eq!(response_id, "resp1");
assert!(token_usage.is_none());
assert!(!can_append);
}
other => panic!("unexpected event: {other:?}"),
}
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/codex-api/tests/sse_end_to_end.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,11 @@ async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()>
ResponseEvent::Completed {
response_id,
token_usage,
can_append,
} => {
assert_eq!(response_id, "resp1");
assert!(token_usage.is_none());
assert!(!can_append);
}
other => panic!("unexpected third event: {other:?}"),
}
Expand Down
28 changes: 19 additions & 9 deletions codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ use tokio::sync::oneshot;
use tokio::sync::oneshot::error::TryRecvError;
use tokio_tungstenite::tungstenite::Error;
use tokio_tungstenite::tungstenite::Message;
use tracing::trace;
use tracing::warn;

use crate::AuthManager;
Expand Down Expand Up @@ -185,6 +186,7 @@ pub struct ModelClientSession {
/// Snapshot of the most recently completed response, consulted when deciding
/// whether a follow-up request can be sent incrementally.
struct LastResponse {
/// Id reported by the `Completed` event; may be empty.
response_id: String,
/// Items collected while the response was streaming (moved out of the
/// stream's accumulator when the `Completed` event arrives).
items_added: Vec<ResponseItem>,
/// Whether the client can append more items to the long-running websocket response.
can_append: bool,
}

enum WebsocketStreamOutcome {
Expand Down Expand Up @@ -550,6 +552,9 @@ impl ModelClientSession {
let mut request_without_input = request.clone();
request_without_input.input.clear();
if previous_without_input != request_without_input {
trace!(
"incremental request failed, properties didn't match {previous_without_input:?} != {request_without_input:?}"
);
return None;
}

Expand All @@ -565,6 +570,7 @@ impl ModelClientSession {
{
Some(request.input[baseline_len..].to_vec())
} else {
trace!("incremental request failed, items didn't match");
None
}
}
Expand All @@ -583,18 +589,19 @@ impl ModelClientSession {
payload: ResponseCreateWsRequest,
request: &ResponsesApiRequest,
) -> ResponsesWsRequest {
let last_response = self.get_last_response();
let Some(last_response) = self.get_last_response() else {
return ResponsesWsRequest::ResponseCreate(payload);
};
let responses_websockets_v2_enabled = self.client.responses_websockets_v2_enabled();
let incremental_items = self.get_incremental_items(request, last_response.as_ref());
if !responses_websockets_v2_enabled && !last_response.can_append {
trace!("incremental request failed, can't append");
return ResponsesWsRequest::ResponseCreate(payload);
}
let incremental_items = self.get_incremental_items(request, Some(&last_response));
if let Some(append_items) = incremental_items {
if responses_websockets_v2_enabled
&& let Some(previous_response_id) = last_response
.as_ref()
.map(|last_response| last_response.response_id.clone())
.filter(|id| !id.is_empty())
{
if responses_websockets_v2_enabled && !last_response.response_id.is_empty() {
let payload = ResponseCreateWsRequest {
previous_response_id: Some(previous_response_id),
previous_response_id: Some(last_response.response_id),
input: append_items,
..payload
};
Expand Down Expand Up @@ -1014,6 +1021,7 @@ where
Ok(ResponseEvent::Completed {
response_id,
token_usage,
can_append,
}) => {
if let Some(usage) = &token_usage {
otel_manager.sse_event_completed(
Expand All @@ -1028,12 +1036,14 @@ where
let _ = sender.send(LastResponse {
response_id: response_id.clone(),
items_added: std::mem::take(&mut items_added),
can_append,
});
}
if tx_event
.send(Ok(ResponseEvent::Completed {
response_id,
token_usage,
can_append,
}))
.await
.is_err()
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4990,6 +4990,7 @@ async fn try_run_sampling_request(
ResponseEvent::Completed {
response_id: _,
token_usage,
can_append: _,
} => {
if let Some(state) = plan_mode_state.as_mut() {
flush_proposed_plan_segments_all(&sess, &turn_context, state).await;
Expand Down
10 changes: 10 additions & 0 deletions codex-rs/core/tests/common/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,16 @@ pub fn ev_done() -> Value {
})
}

/// Convenience: `response.done` event carrying the given response id and an
/// all-zero usage payload.
pub fn ev_done_with_id(id: &str) -> Value {
    // Build the payload inside-out; key order within each object matches the
    // single-literal form.
    let usage = serde_json::json!({
        "input_tokens": 0,
        "input_tokens_details": null,
        "output_tokens": 0,
        "output_tokens_details": null,
        "total_tokens": 0
    });
    let response = serde_json::json!({ "id": id, "usage": usage });
    serde_json::json!({ "type": "response.done", "response": response })
}

/// Convenience: SSE event for a created response with a specific id.
pub fn ev_response_created(id: &str) -> Value {
serde_json::json!({
Expand Down
45 changes: 43 additions & 2 deletions codex-rs/core/tests/suite/client_websockets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ use core_test_support::responses::WebSocketConnectionConfig;
use core_test_support::responses::WebSocketTestServer;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_done;
use core_test_support::responses::ev_done_with_id;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::start_websocket_server;
use core_test_support::responses::start_websocket_server_with_headers;
Expand Down Expand Up @@ -574,7 +576,7 @@ async fn responses_websocket_appends_on_prefix() {
vec![
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "assistant output"),
ev_completed("resp-1"),
ev_done(),
],
vec![ev_response_created("resp-2"), ev_completed("resp-2")],
]])
Expand Down Expand Up @@ -610,6 +612,45 @@ async fn responses_websocket_appends_on_prefix() {
server.shutdown().await;
}

/// When the first response is terminated with `ev_completed`, a follow-up
/// prompt that extends the first one must still be sent as a `response.create`
/// carrying the full input.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_creates_on_prefix_when_previous_completion_cannot_append() {
    skip_if_no_network!();

    // One connection serving two scripted responses.
    let first_response = vec![
        ev_response_created("resp-1"),
        ev_assistant_message("msg-1", "assistant output"),
        ev_completed("resp-1"),
    ];
    let second_response = vec![ev_response_created("resp-2"), ev_completed("resp-2")];
    let server = start_websocket_server(vec![vec![first_response, second_response]]).await;

    let harness = websocket_harness(&server).await;
    let mut session = harness.client.new_session();

    let first_prompt = prompt_with_input(vec![message_item("hello")]);
    // Same leading item as `first_prompt`, plus the assistant reply and a new message.
    let second_prompt = prompt_with_input(vec![
        message_item("hello"),
        assistant_message_item("msg-1", "assistant output"),
        message_item("second"),
    ]);

    for prompt in [&first_prompt, &second_prompt] {
        stream_until_complete(&mut session, &harness, prompt).await;
    }

    let requests = server.single_connection();
    assert_eq!(requests.len(), 2);
    let follow_up = requests.get(1).expect("missing request").body_json();

    // The second request must be a fresh create with the complete input.
    assert_eq!(follow_up["type"].as_str(), Some("response.create"));
    let expected_input =
        serde_json::to_value(&second_prompt.input).expect("serialize full input");
    assert_eq!(follow_up["input"], expected_input);

    server.shutdown().await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_creates_on_non_prefix() {
skip_if_no_network!();
Expand Down Expand Up @@ -687,7 +728,7 @@ async fn responses_websocket_v2_creates_with_previous_response_id_on_prefix() {
vec![
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "assistant output"),
ev_completed("resp-1"),
ev_done_with_id("resp-1"),
],
vec![ev_response_created("resp-2"), ev_completed("resp-2")],
]])
Expand Down
Loading