Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions codex-rs/codex-api/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use codex_client::Request;
use http::HeaderMap;
use http::HeaderValue;

/// Provides bearer and account identity information for API requests.
///
Expand All @@ -12,16 +14,20 @@ pub trait AuthProvider: Send + Sync {
}
}

pub(crate) fn add_auth_headers<A: AuthProvider>(auth: &A, mut req: Request) -> Request {
pub(crate) fn add_auth_headers_to_header_map<A: AuthProvider>(auth: &A, headers: &mut HeaderMap) {
if let Some(token) = auth.bearer_token()
&& let Ok(header) = format!("Bearer {token}").parse()
&& let Ok(header) = HeaderValue::from_str(&format!("Bearer {token}"))
{
let _ = req.headers.insert(http::header::AUTHORIZATION, header);
let _ = headers.insert(http::header::AUTHORIZATION, header);
}
if let Some(account_id) = auth.account_id()
&& let Ok(header) = account_id.parse()
&& let Ok(header) = HeaderValue::from_str(&account_id)
{
let _ = req.headers.insert("ChatGPT-Account-ID", header);
let _ = headers.insert("ChatGPT-Account-ID", header);
}
}

pub(crate) fn add_auth_headers<A: AuthProvider>(auth: &A, mut req: Request) -> Request {
add_auth_headers_to_header_map(auth, &mut req.headers);
req
}
67 changes: 16 additions & 51 deletions codex-rs/codex-api/src/endpoint/compact.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::auth::AuthProvider;
use crate::auth::add_auth_headers;
use crate::common::CompactionInput;
use crate::endpoint::session::EndpointSession;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::models::ResponseItem;
Expand All @@ -14,28 +13,23 @@ use serde_json::to_value;
use std::sync::Arc;

pub struct CompactClient<T: HttpTransport, A: AuthProvider> {
transport: T,
provider: Provider,
auth: A,
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
session: EndpointSession<T, A>,
}

impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
Self {
transport,
provider,
auth,
request_telemetry: None,
session: EndpointSession::new(transport, provider, auth),
}
}

pub fn with_telemetry(mut self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
self.request_telemetry = request;
self
pub fn with_telemetry(self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
Self {
session: self.session.with_request_telemetry(request),
}
}

fn path(&self) -> &'static str {
fn path() -> &'static str {
"responses/compact"
}

Expand All @@ -44,21 +38,10 @@ impl<T: HttpTransport, A: AuthProvider> CompactClient<T, A> {
body: serde_json::Value,
extra_headers: HeaderMap,
) -> Result<Vec<ResponseItem>, ApiError> {
let path = self.path();
let builder = || {
let mut req = self.provider.build_request(Method::POST, path);
req.headers.extend(extra_headers.clone());
req.body = Some(body.clone());
add_auth_headers(&self.auth, req)
};

let resp = run_with_request_telemetry(
self.provider.retry.to_policy(),
self.request_telemetry.clone(),
builder,
|req| self.transport.execute(req),
)
.await?;
let resp = self
.session
.execute(Method::POST, Self::path(), extra_headers, Some(body))
.await?;
let parsed: CompactHistoryResponse =
serde_json::from_slice(&resp.body).map_err(|e| ApiError::Stream(e.to_string()))?;
Ok(parsed.output)
Expand All @@ -83,14 +66,11 @@ struct CompactHistoryResponse {
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::RetryConfig;
use async_trait::async_trait;
use codex_client::Request;
use codex_client::Response;
use codex_client::StreamResponse;
use codex_client::TransportError;
use http::HeaderMap;
use std::time::Duration;

#[derive(Clone, Default)]
struct DummyTransport;
Expand All @@ -115,26 +95,11 @@ mod tests {
}
}

fn provider() -> Provider {
Provider {
name: "test".to_string(),
base_url: "https://example.com/v1".to_string(),
query_params: None,
headers: HeaderMap::new(),
retry: RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: true,
retry_transport: true,
},
stream_idle_timeout: Duration::from_secs(1),
}
}

#[test]
fn path_is_responses_compact() {
let client = CompactClient::new(DummyTransport, provider(), DummyAuth);
assert_eq!(client.path(), "responses/compact");
assert_eq!(
CompactClient::<DummyTransport, DummyAuth>::path(),
"responses/compact"
);
}
}
2 changes: 1 addition & 1 deletion codex-rs/codex-api/src/endpoint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ pub mod compact;
pub mod models;
pub mod responses;
pub mod responses_websocket;
mod streaming;
mod session;
50 changes: 19 additions & 31 deletions codex-rs/codex-api/src/endpoint/models.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::auth::AuthProvider;
use crate::auth::add_auth_headers;
use crate::endpoint::session::EndpointSession;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::openai_models::ModelInfo;
Expand All @@ -13,53 +12,42 @@ use http::header::ETAG;
use std::sync::Arc;

pub struct ModelsClient<T: HttpTransport, A: AuthProvider> {
transport: T,
provider: Provider,
auth: A,
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
session: EndpointSession<T, A>,
}

impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
Self {
transport,
provider,
auth,
request_telemetry: None,
session: EndpointSession::new(transport, provider, auth),
}
}

pub fn with_telemetry(mut self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
self.request_telemetry = request;
self
pub fn with_telemetry(self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
Self {
session: self.session.with_request_telemetry(request),
}
}

fn path(&self) -> &'static str {
fn path() -> &'static str {
"models"
}

fn append_client_version_query(req: &mut codex_client::Request, client_version: &str) {
let separator = if req.url.contains('?') { '&' } else { '?' };
req.url = format!("{}{}client_version={client_version}", req.url, separator);
}

pub async fn list_models(
&self,
client_version: &str,
extra_headers: HeaderMap,
) -> Result<(Vec<ModelInfo>, Option<String>), ApiError> {
let builder = || {
let mut req = self.provider.build_request(Method::GET, self.path());
req.headers.extend(extra_headers.clone());

let separator = if req.url.contains('?') { '&' } else { '?' };
req.url = format!("{}{}client_version={client_version}", req.url, separator);

add_auth_headers(&self.auth, req)
};

let resp = run_with_request_telemetry(
self.provider.retry.to_policy(),
self.request_telemetry.clone(),
builder,
|req| self.transport.execute(req),
)
.await?;
let resp = self
.session
.execute_with(Method::GET, Self::path(), extra_headers, None, |req| {
Self::append_client_version_query(req, client_version);
})
.await?;

let header_etag = resp
.headers
Expand Down
48 changes: 33 additions & 15 deletions codex-rs/codex-api/src/endpoint/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::common::Prompt as ApiPrompt;
use crate::common::Reasoning;
use crate::common::ResponseStream;
use crate::common::TextControls;
use crate::endpoint::streaming::StreamingClient;
use crate::endpoint::session::EndpointSession;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::requests::ResponsesRequest;
Expand All @@ -16,13 +16,16 @@ use codex_client::RequestCompression;
use codex_client::RequestTelemetry;
use codex_protocol::protocol::SessionSource;
use http::HeaderMap;
use http::HeaderValue;
use http::Method;
use serde_json::Value;
use std::sync::Arc;
use std::sync::OnceLock;
use tracing::instrument;

pub struct ResponsesClient<T: HttpTransport, A: AuthProvider> {
streaming: StreamingClient<T, A>,
session: EndpointSession<T, A>,
sse_telemetry: Option<Arc<dyn SseTelemetry>>,
}

#[derive(Default)]
Expand All @@ -42,7 +45,8 @@ pub struct ResponsesOptions {
impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
Self {
streaming: StreamingClient::new(transport, provider, auth),
session: EndpointSession::new(transport, provider, auth),
sse_telemetry: None,
}
}

Expand All @@ -52,7 +56,8 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
sse: Option<Arc<dyn SseTelemetry>>,
) -> Self {
Self {
streaming: self.streaming.with_telemetry(request, sse),
session: self.session.with_request_telemetry(request),
sse_telemetry: sse,
}
}

Expand Down Expand Up @@ -102,12 +107,12 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
.store_override(store_override)
.extra_headers(extra_headers)
.compression(compression)
.build(self.streaming.provider())?;
.build(self.session.provider())?;

self.stream_request(request, turn_state).await
}

fn path(&self) -> &'static str {
fn path() -> &'static str {
"responses"
}

Expand All @@ -118,20 +123,33 @@ impl<T: HttpTransport, A: AuthProvider> ResponsesClient<T, A> {
compression: Compression,
turn_state: Option<Arc<OnceLock<String>>>,
) -> Result<ResponseStream, ApiError> {
let compression = match compression {
let request_compression = match compression {
Compression::None => RequestCompression::None,
Compression::Zstd => RequestCompression::Zstd,
};

self.streaming
.stream(
self.path(),
body,
let stream_response = self
.session
.stream_with(
Method::POST,
Self::path(),
extra_headers,
compression,
spawn_response_stream,
turn_state,
Some(body),
|req| {
req.headers.insert(
http::header::ACCEPT,
HeaderValue::from_static("text/event-stream"),
);
req.compression = request_compression;
},
)
.await
.await?;

Ok(spawn_response_stream(
stream_response,
self.session.provider().stream_idle_timeout,
self.sse_telemetry.clone(),
turn_state,
))
}
}
18 changes: 2 additions & 16 deletions codex-rs/codex-api/src/endpoint/responses_websocket.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::auth::AuthProvider;
use crate::auth::add_auth_headers_to_header_map;
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::common::ResponsesWsRequest;
Expand All @@ -11,7 +12,6 @@ use codex_client::TransportError;
use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use http::HeaderValue;
use serde_json::Value;
use std::sync::Arc;
use std::sync::OnceLock;
Expand Down Expand Up @@ -134,7 +134,7 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {

let mut headers = self.provider.headers.clone();
headers.extend(extra_headers);
apply_auth_headers(&mut headers, &self.auth);
add_auth_headers_to_header_map(&self.auth, &mut headers);

let (stream, server_reasoning_included) =
connect_websocket(ws_url, headers, turn_state).await?;
Expand All @@ -147,20 +147,6 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
}
}

// TODO (pakrym): share with /auth
fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) {
if let Some(token) = auth.bearer_token()
&& let Ok(header) = HeaderValue::from_str(&format!("Bearer {token}"))
{
let _ = headers.insert(http::header::AUTHORIZATION, header);
}
if let Some(account_id) = auth.account_id()
&& let Ok(header) = HeaderValue::from_str(&account_id)
{
let _ = headers.insert("ChatGPT-Account-ID", header);
}
}

async fn connect_websocket(
url: Url,
headers: HeaderMap,
Expand Down
Loading
Loading