Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions codex-rs/codex-api/src/endpoint/responses_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::common::ResponseStream;
use crate::common::ResponsesWsRequest;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::rate_limits::parse_rate_limit_event;
use crate::sse::responses::ResponsesStreamEvent;
use crate::sse::responses::process_responses_event;
use crate::telemetry::WebsocketTelemetry;
Expand Down Expand Up @@ -33,13 +34,15 @@ use url::Url;

/// Concrete websocket stream type used throughout this module: a websocket layered over a
/// TCP stream that may or may not be wrapped in TLS.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Name of the header that carries the Codex turn state.
const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
/// Name of the response header whose value, when present and valid as a string, is captured
/// as the models ETag in `connect_websocket`.
const X_MODELS_ETAG_HEADER: &str = "x-models-etag";
/// Name of the response header whose mere presence (the value is not inspected) marks the
/// connection as "server reasoning included" in `connect_websocket`.
const X_REASONING_INCLUDED_HEADER: &str = "x-reasoning-included";

/// An established Responses websocket connection together with the per-connection metadata
/// captured when it was opened.
pub struct ResponsesWebsocketConnection {
// Shared, mutex-guarded slot holding the websocket stream; cloned via `Arc` into the task
// spawned for each request.
// NOTE(review): presumably the `Option` is emptied when the stream is closed or taken —
// confirm against the code that locks this mutex.
stream: Arc<Mutex<Option<WsStream>>>,
// TODO (pakrym): is this the right place for timeout?
// Populated from the provider's `stream_idle_timeout` when the connection is created.
idle_timeout: Duration,
// Whether the `x-reasoning-included` header was present on the connection response. When
// true, a `ResponseEvent::ServerReasoningIncluded(true)` event is sent at the start of a
// request stream.
server_reasoning_included: bool,
// Value of the `x-models-etag` header from the connection response, if any. When set, it is
// sent as a `ResponseEvent::ModelsEtag` event at the start of a request stream.
models_etag: Option<String>,
// Optional telemetry sink for websocket activity.
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
}

Expand All @@ -48,12 +51,14 @@ impl ResponsesWebsocketConnection {
stream: WsStream,
idle_timeout: Duration,
server_reasoning_included: bool,
models_etag: Option<String>,
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
) -> Self {
Self {
stream: Arc::new(Mutex::new(Some(stream))),
idle_timeout,
server_reasoning_included,
models_etag,
telemetry,
}
}
Expand All @@ -71,12 +76,16 @@ impl ResponsesWebsocketConnection {
let stream = Arc::clone(&self.stream);
let idle_timeout = self.idle_timeout;
let server_reasoning_included = self.server_reasoning_included;
let models_etag = self.models_etag.clone();
let telemetry = self.telemetry.clone();
let request_body = serde_json::to_value(&request).map_err(|err| {
ApiError::Stream(format!("failed to encode websocket request: {err}"))
})?;

tokio::spawn(async move {
if let Some(etag) = models_etag {
let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await;
}
if server_reasoning_included {
let _ = tx_event
.send(Ok(ResponseEvent::ServerReasoningIncluded(true)))
Expand Down Expand Up @@ -136,12 +145,13 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
headers.extend(extra_headers);
add_auth_headers_to_header_map(&self.auth, &mut headers);

let (stream, server_reasoning_included) =
connect_websocket(ws_url, headers, turn_state).await?;
let (stream, server_reasoning_included, models_etag) =
connect_websocket(ws_url, headers, turn_state.clone()).await?;
Ok(ResponsesWebsocketConnection::new(
stream,
self.provider.stream_idle_timeout,
server_reasoning_included,
models_etag,
telemetry,
))
}
Expand All @@ -151,7 +161,7 @@ async fn connect_websocket(
url: Url,
headers: HeaderMap,
turn_state: Option<Arc<OnceLock<String>>>,
) -> Result<(WsStream, bool), ApiError> {
) -> Result<(WsStream, bool, Option<String>), ApiError> {
info!("connecting to websocket: {url}");

let mut request = url
Expand All @@ -177,6 +187,11 @@ async fn connect_websocket(
};

let reasoning_included = response.headers().contains_key(X_REASONING_INCLUDED_HEADER);
let models_etag = response
.headers()
.get(X_MODELS_ETAG_HEADER)
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);
if let Some(turn_state) = turn_state
&& let Some(header_value) = response
.headers()
Expand All @@ -185,7 +200,7 @@ async fn connect_websocket(
{
let _ = turn_state.set(header_value.to_string());
}
Ok((stream, reasoning_included))
Ok((stream, reasoning_included, models_etag))
}

fn map_ws_error(err: WsError, url: &Url) -> ApiError {
Expand Down Expand Up @@ -273,6 +288,12 @@ async fn run_websocket_response_stream(
continue;
}
};
if event.kind() == "codex.rate_limits" {
if let Some(snapshot) = parse_rate_limit_event(&text) {
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
}
continue;
}
match process_responses_event(event) {
Ok(Some(event)) => {
let is_completed = matches!(event, ResponseEvent::Completed { .. });
Expand Down
66 changes: 66 additions & 0 deletions codex-rs/codex-api/src/rate_limits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use codex_protocol::account::PlanType;
use codex_protocol::protocol::CreditsSnapshot;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::RateLimitWindow;
use http::HeaderMap;
use serde::Deserialize;
use std::fmt::Display;

#[derive(Debug)]
Expand Down Expand Up @@ -41,6 +43,70 @@ pub fn parse_rate_limit(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
})
}

/// Wire representation of a single rate-limit window inside a rate-limit event payload.
///
/// Converted into a `RateLimitWindow` by `map_event_window`.
#[derive(Debug, Deserialize)]
struct RateLimitEventWindow {
// Required: deserialization of the whole event fails if this field is absent.
used_percent: f64,
// Optional window length; copied verbatim into `RateLimitWindow::window_minutes`.
window_minutes: Option<i64>,
// Optional reset time; note the wire name is `reset_at` while the snapshot field it maps to
// is `resets_at`.
// NOTE(review): presumably Unix epoch seconds — confirm with the server contract.
reset_at: Option<i64>,
}

/// Wire representation of the `rate_limits` object of a rate-limit event.
///
/// Only the `primary` and `secondary` windows are modelled; either may be absent or `null`.
#[derive(Debug, Deserialize)]
struct RateLimitEventDetails {
// Becomes `RateLimitSnapshot::primary` after mapping.
primary: Option<RateLimitEventWindow>,
// Becomes `RateLimitSnapshot::secondary` after mapping.
secondary: Option<RateLimitEventWindow>,
}

/// Wire representation of the `credits` object of a rate-limit event.
///
/// Its fields are copied one-to-one into a `CreditsSnapshot` by `parse_rate_limit_event`.
#[derive(Debug, Deserialize)]
struct RateLimitEventCredits {
// Required boolean flag.
has_credits: bool,
// Required boolean flag.
unlimited: bool,
// Optional balance, kept as the string sent on the wire (no numeric parsing is done here).
balance: Option<String>,
}

/// Top-level wire representation of a rate-limit event as consumed by
/// `parse_rate_limit_event`.
#[derive(Debug, Deserialize)]
struct RateLimitEvent {
// Deserialized from the JSON `type` field; `parse_rate_limit_event` only accepts the value
// `codex.rate_limits`.
#[serde(rename = "type")]
kind: String,
// Optional plan type, passed through unchanged to `RateLimitSnapshot::plan_type`.
plan_type: Option<PlanType>,
// Optional window details; when missing, both snapshot windows are `None`.
rate_limits: Option<RateLimitEventDetails>,
// Optional credits information.
credits: Option<RateLimitEventCredits>,
}

/// Decodes a JSON rate-limit event into a `RateLimitSnapshot`.
///
/// Returns `None` when `payload` cannot be deserialized as a `RateLimitEvent`, or when its
/// `type` field is anything other than `codex.rate_limits`. A missing `rate_limits` object
/// yields a snapshot whose `primary` and `secondary` windows are both `None`.
pub fn parse_rate_limit_event(payload: &str) -> Option<RateLimitSnapshot> {
    let parsed = serde_json::from_str::<RateLimitEvent>(payload).ok()?;
    if parsed.kind != "codex.rate_limits" {
        return None;
    }

    // Translate both windows when the details object is present; otherwise leave them unset.
    let (primary, secondary) = match parsed.rate_limits.as_ref() {
        Some(limits) => (
            map_event_window(limits.primary.as_ref()),
            map_event_window(limits.secondary.as_ref()),
        ),
        None => (None, None),
    };

    // Credits are copied field-for-field into the snapshot type.
    let credits = parsed.credits.map(|wire| CreditsSnapshot {
        has_credits: wire.has_credits,
        unlimited: wire.unlimited,
        balance: wire.balance,
    });

    Some(RateLimitSnapshot {
        primary,
        secondary,
        credits,
        plan_type: parsed.plan_type,
    })
}

/// Converts an optional wire-format window into the snapshot's `RateLimitWindow`.
///
/// `None` in gives `None` out; otherwise every field is copied across, with the wire field
/// `reset_at` landing in `resets_at`.
fn map_event_window(window: Option<&RateLimitEventWindow>) -> Option<RateLimitWindow> {
    window.map(|source| RateLimitWindow {
        used_percent: source.used_percent,
        window_minutes: source.window_minutes,
        resets_at: source.reset_at,
    })
}

/// Parses the `x-codex-promo-message` header into an optional promo message string.
pub fn parse_promo_message(headers: &HeaderMap) -> Option<String> {
parse_header_str(headers, "x-codex-promo-message")
Expand Down
6 changes: 6 additions & 0 deletions codex-rs/codex-api/src/sse/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ pub struct ResponsesStreamEvent {
content_index: Option<i64>,
}

impl ResponsesStreamEvent {
/// Returns this event's `kind` field as a string slice.
///
/// Exposed so code outside this module can inspect the event kind before handing the event
/// to `process_responses_event`.
// NOTE(review): presumably `kind` is the wire-level `type` discriminator — confirm against
// the struct's serde attributes.
pub fn kind(&self) -> &str {
&self.kind
}
}

#[derive(Debug)]
pub enum ResponsesEventError {
Api(ApiError),
Expand Down
86 changes: 86 additions & 0 deletions codex-rs/core/tests/suite/client_websockets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use codex_otel::OtelManager;
use codex_otel::metrics::MetricsClient;
use codex_otel::metrics::MetricsConfig;
use codex_protocol::ThreadId;
use codex_protocol::account::PlanType;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::load_default_config_for_test;
use core_test_support::responses::WebSocketConnectionConfig;
Expand All @@ -29,6 +30,7 @@ use core_test_support::skip_if_no_network;
use futures::StreamExt;
use opentelemetry_sdk::metrics::InMemoryMetricExporter;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::sync::Arc;
use std::time::Duration;
use tempfile::TempDir;
Expand Down Expand Up @@ -136,6 +138,90 @@ async fn responses_websocket_emits_reasoning_included_event() {
server.shutdown().await;
}

// Verifies that a `codex.rate_limits` websocket message is surfaced to the client as a
// `ResponseEvent::RateLimits` snapshot, and that the `X-Models-Etag` and
// `X-Reasoning-Included` response headers are surfaced as `ModelsEtag` and
// `ServerReasoningIncluded(true)` events on the same stream.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_emits_rate_limit_events() {
skip_if_no_network!();

// Payload the mock server sends first. It deliberately includes keys the parser does not
// model (`allowed`, `limit_reached`, `code_review_rate_limits`, `promo`) as well as a
// `null` secondary window.
let rate_limit_event = json!({
"type": "codex.rate_limits",
"plan_type": "plus",
"rate_limits": {
"allowed": true,
"limit_reached": false,
"primary": {
"used_percent": 42,
"window_minutes": 60,
"reset_at": 1700000000
},
"secondary": null
},
"code_review_rate_limits": null,
"credits": {
"has_credits": true,
"unlimited": false,
"balance": "123"
},
"promo": null
});

// One connection serving one request: the rate-limit event, then a created/completed pair.
let server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
requests: vec![vec![
rate_limit_event,
ev_response_created("resp-1"),
ev_completed("resp-1"),
]],
// Headers attached to the server's response for this connection.
response_headers: vec![
("X-Models-Etag".to_string(), "etag-123".to_string()),
("X-Reasoning-Included".to_string(), "true".to_string()),
],
}])
.await;

let harness = websocket_harness(&server).await;
let mut session = harness.client.new_session(None);
let prompt = prompt_with_input(vec![message_item("hello")]);

let mut stream = session
.stream(&prompt)
.await
.expect("websocket stream failed");

// Record each event of interest as it arrives; everything else is ignored.
let mut saw_rate_limits = None;
let mut saw_models_etag = None;
let mut saw_reasoning_included = false;

while let Some(event) = stream.next().await {
match event.expect("event") {
ResponseEvent::RateLimits(snapshot) => {
saw_rate_limits = Some(snapshot);
}
ResponseEvent::ModelsEtag(etag) => {
saw_models_etag = Some(etag);
}
ResponseEvent::ServerReasoningIncluded(true) => {
saw_reasoning_included = true;
}
// Stop reading once the response is complete.
ResponseEvent::Completed { .. } => break,
_ => {}
}
}

// The integer `42` in the payload is expected to surface as the float `42.0`, and the wire
// field `reset_at` as `resets_at`.
let rate_limits = saw_rate_limits.expect("missing rate limits");
let primary = rate_limits.primary.expect("missing primary window");
assert_eq!(primary.used_percent, 42.0);
assert_eq!(primary.window_minutes, Some(60));
assert_eq!(primary.resets_at, Some(1_700_000_000));
assert_eq!(rate_limits.plan_type, Some(PlanType::Plus));
let credits = rate_limits.credits.expect("missing credits");
assert!(credits.has_credits);
assert!(!credits.unlimited);
assert_eq!(credits.balance.as_deref(), Some("123"));
// Header-derived events must have been emitted as well.
assert_eq!(saw_models_etag.as_deref(), Some("etag-123"));
assert!(saw_reasoning_included);

server.shutdown().await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_appends_on_prefix() {
skip_if_no_network!();
Expand Down
Loading