From 9ca9c5f88cfe90be935dd4db814c71a7aebdfec3 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Wed, 4 Feb 2026 14:55:28 +0000 Subject: [PATCH] feat: add phase 1 mem client --- codex-rs/codex-api/README.md | 9 + codex-rs/codex-api/src/common.rs | 28 ++ codex-rs/codex-api/src/endpoint/memories.rs | 108 ++++++++ codex-rs/codex-api/src/endpoint/mod.rs | 1 + codex-rs/codex-api/src/lib.rs | 5 + codex-rs/core/src/client.rs | 58 +++- codex-rs/core/src/lib.rs | 1 + codex-rs/core/src/memory_trace.rs | 292 ++++++++++++++++++++ 8 files changed, 498 insertions(+), 4 deletions(-) create mode 100644 codex-rs/codex-api/src/endpoint/memories.rs create mode 100644 codex-rs/core/src/memory_trace.rs diff --git a/codex-rs/codex-api/README.md b/codex-rs/codex-api/README.md index c1f7d230c0e..0570cf7570e 100644 --- a/codex-rs/codex-api/README.md +++ b/codex-rs/codex-api/README.md @@ -29,4 +29,13 @@ The public interface of this crate is intentionally small and uniform: - Output: `Vec`. - `CompactClient::compact_input(&CompactionInput, extra_headers)` wraps the JSON encoding and retry/telemetry wiring. +- **Memory trace summarize endpoint** + - Input: `MemoryTraceSummarizeInput` (re-exported as `codex_api::MemoryTraceSummarizeInput`): + - `model: String`. + - `traces: Vec`. + - `MemoryTrace` includes `id`, `metadata.source_path`, and normalized `items`. + - `reasoning: Option`. + - Output: `Vec`. + - `MemoriesClient::trace_summarize_input(&MemoryTraceSummarizeInput, extra_headers)` wraps JSON encoding and retry/telemetry wiring. + All HTTP details (URLs, headers, retry/backoff policies, SSE framing) are encapsulated in `codex-api` and `codex-client`. Callers construct prompts/inputs using protocol types and work with typed streams of `ResponseEvent` or compacted `ResponseItem` values. diff --git a/codex-rs/codex-api/src/common.rs b/codex-rs/codex-api/src/common.rs index a9127644f14..f1a996f12c7 100644 --- a/codex-rs/codex-api/src/common.rs +++ b/codex-rs/codex-api/src/common.rs @@ -6,6 +6,7 @@ use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig; use codex_protocol::protocol::RateLimitSnapshot; use codex_protocol::protocol::TokenUsage; use futures::Stream; +use serde::Deserialize; use serde::Serialize; use serde_json::Value; use std::pin::Pin; @@ -37,6 +38,33 @@ pub struct CompactionInput<'a> { pub instructions: &'a str, } +/// Canonical input payload for the memory trace summarize endpoint. +#[derive(Debug, Clone, Serialize)] +pub struct MemoryTraceSummarizeInput { + pub model: String, + pub traces: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, +} + +#[derive(Debug, Clone, Serialize)] +pub struct MemoryTrace { + pub id: String, + pub metadata: MemoryTraceMetadata, + pub items: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub struct MemoryTraceMetadata { + pub source_path: String, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +pub struct MemoryTraceSummaryOutput { + pub trace_summary: String, + pub memory_summary: String, +} + #[derive(Debug)] pub enum ResponseEvent { Created, diff --git a/codex-rs/codex-api/src/endpoint/memories.rs b/codex-rs/codex-api/src/endpoint/memories.rs new file mode 100644 index 00000000000..c8f35d7e162 --- /dev/null +++ b/codex-rs/codex-api/src/endpoint/memories.rs @@ -0,0 +1,108 @@ +use crate::auth::AuthProvider; +use crate::common::MemoryTraceSummarizeInput; +use crate::common::MemoryTraceSummaryOutput; +use crate::endpoint::session::EndpointSession; +use crate::error::ApiError; +use crate::provider::Provider; +use codex_client::HttpTransport; +use codex_client::RequestTelemetry; +use http::HeaderMap; +use http::Method; +use serde::Deserialize; +use serde_json::to_value; +use std::sync::Arc; + +pub struct MemoriesClient { + session: EndpointSession, +} + +impl MemoriesClient { + pub fn new(transport: T, provider: Provider, auth: A) -> Self { + Self { + session: EndpointSession::new(transport, provider, auth), + } + } + + pub fn with_telemetry(self, request: Option>) -> Self { + Self { + session: self.session.with_request_telemetry(request), + } + } + + fn path() -> &'static str { + "memories/trace_summarize" + } + + pub async fn trace_summarize( + &self, + body: serde_json::Value, + extra_headers: HeaderMap, + ) -> Result, ApiError> { + let resp = self + .session + .execute(Method::POST, Self::path(), extra_headers, Some(body)) + .await?; + let parsed: TraceSummarizeResponse = + serde_json::from_slice(&resp.body).map_err(|e| ApiError::Stream(e.to_string()))?; + Ok(parsed.output) + } + + pub async fn trace_summarize_input( + &self, + input: &MemoryTraceSummarizeInput, + extra_headers: HeaderMap, + ) -> Result, ApiError> { + let body = to_value(input).map_err(|e| { + ApiError::Stream(format!( + "failed to encode memory trace summarize input: {e}" + )) + })?; + self.trace_summarize(body, extra_headers).await + } +} + +#[derive(Debug, Deserialize)] +struct TraceSummarizeResponse { + output: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + use codex_client::Request; + use codex_client::Response; + use codex_client::StreamResponse; + use codex_client::TransportError; + + #[derive(Clone, Default)] + struct DummyTransport; + + #[async_trait] + impl HttpTransport for DummyTransport { + async fn execute(&self, _req: Request) -> Result { + Err(TransportError::Build("execute should not run".to_string())) + } + + async fn stream(&self, _req: Request) -> Result { + Err(TransportError::Build("stream should not run".to_string())) + } + } + + #[derive(Clone, Default)] + struct DummyAuth; + + impl AuthProvider for DummyAuth { + fn bearer_token(&self) -> Option { + None + } + } + + #[test] + fn path_is_memories_trace_summarize() { + assert_eq!( + MemoriesClient::::path(), + "memories/trace_summarize" + ); + } +} diff --git a/codex-rs/codex-api/src/endpoint/mod.rs b/codex-rs/codex-api/src/endpoint/mod.rs index 23579ffcf16..0dede138e81 100644 --- a/codex-rs/codex-api/src/endpoint/mod.rs +++ b/codex-rs/codex-api/src/endpoint/mod.rs @@ -1,5 +1,6 @@ pub mod aggregate; pub mod compact; +pub mod memories; pub mod models; pub mod responses; pub mod responses_websocket; diff --git a/codex-rs/codex-api/src/lib.rs b/codex-rs/codex-api/src/lib.rs index b0c70084d41..70652d2d78b 100644 --- a/codex-rs/codex-api/src/lib.rs +++ b/codex-rs/codex-api/src/lib.rs @@ -15,6 +15,10 @@ pub use codex_client::TransportError; pub use crate::auth::AuthProvider; pub use crate::common::CompactionInput; +pub use crate::common::MemoryTrace; +pub use crate::common::MemoryTraceMetadata; +pub use crate::common::MemoryTraceSummarizeInput; +pub use crate::common::MemoryTraceSummaryOutput; pub use crate::common::Prompt; pub use crate::common::ResponseAppendWsRequest; pub use crate::common::ResponseCreateWsRequest; @@ -24,6 +28,7 @@ pub use crate::common::ResponsesApiRequest; pub use crate::common::create_text_param_for_request; pub use crate::endpoint::aggregate::AggregateStreamExt; pub use crate::endpoint::compact::CompactClient; +pub use crate::endpoint::memories::MemoriesClient; pub use crate::endpoint::models::ModelsClient; pub use crate::endpoint::responses::ResponsesClient; pub use crate::endpoint::responses::ResponsesOptions; diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 1d0dc7fb658..4833eee1063 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -7,6 +7,10 @@ use crate::api_bridge::map_api_error; use crate::auth::UnauthorizedRecovery; use codex_api::CompactClient as ApiCompactClient; use codex_api::CompactionInput as ApiCompactionInput; +use codex_api::MemoriesClient as ApiMemoriesClient; +use codex_api::MemoryTrace as ApiMemoryTrace; +use codex_api::MemoryTraceSummarizeInput as ApiMemoryTraceSummarizeInput; +use codex_api::MemoryTraceSummaryOutput as ApiMemoryTraceSummaryOutput; use codex_api::Prompt as ApiPrompt; use codex_api::RequestTelemetry; use codex_api::ReqwestTransport; @@ -238,6 +242,55 @@ impl ModelClient { instructions: &instructions, }; + let extra_headers = self.build_subagent_headers(); + client + .compact_input(&payload, extra_headers) + .await + .map_err(map_api_error) + } + + /// Builds memory summaries for each provided normalized trace. + /// + /// This is a unary call (no streaming) to `/v1/memories/trace_summarize`. + pub async fn summarize_memory_traces( + &self, + traces: Vec, + ) -> Result> { + if traces.is_empty() { + return Ok(Vec::new()); + } + + let auth_manager = self.state.auth_manager.clone(); + let auth = match auth_manager.as_ref() { + Some(manager) => manager.auth().await, + None => None, + }; + let api_provider = self + .state + .provider + .to_api_provider(auth.as_ref().map(CodexAuth::internal_auth_mode))?; + let api_auth = auth_provider_from_auth(auth, &self.state.provider)?; + let transport = ReqwestTransport::new(build_reqwest_client()); + let request_telemetry = self.build_request_telemetry(); + let client = ApiMemoriesClient::new(transport, api_provider, api_auth) + .with_telemetry(Some(request_telemetry)); + + let payload = ApiMemoryTraceSummarizeInput { + model: self.state.model_info.slug.clone(), + traces, + reasoning: self.state.effort.map(|effort| Reasoning { + effort: Some(effort), + summary: None, + }), + }; + + client + .trace_summarize_input(&payload, self.build_subagent_headers()) + .await + .map_err(map_api_error) + } + + fn build_subagent_headers(&self) -> ApiHeaderMap { let mut extra_headers = ApiHeaderMap::new(); if let SessionSource::SubAgent(sub) = &self.state.session_source { let subagent = match sub { @@ -250,10 +303,7 @@ impl ModelClient { extra_headers.insert("x-openai-subagent", val); } } - client - .compact_input(&payload, extra_headers) - .await - .map_err(map_api_error) + extra_headers } } diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 9bd4c872536..36fdbdc9a6a 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -161,4 +161,5 @@ pub use codex_protocol::models::ResponseItem; pub use compact::content_items_to_text; pub use event_mapping::parse_turn_item; pub mod compact; +pub mod memory_trace; pub mod otel_init; diff --git a/codex-rs/core/src/memory_trace.rs b/codex-rs/core/src/memory_trace.rs new file mode 100644 index 00000000000..807ff64df18 --- /dev/null +++ b/codex-rs/core/src/memory_trace.rs @@ -0,0 +1,292 @@ +use std::path::Path; +use std::path::PathBuf; + +use crate::ModelClient; +use crate::error::CodexErr; +use crate::error::Result; +use codex_api::MemoryTrace as ApiMemoryTrace; +use codex_api::MemoryTraceMetadata as ApiMemoryTraceMetadata; +use serde_json::Map; +use serde_json::Value; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BuiltTraceMemory { + pub trace_id: String, + pub source_path: PathBuf, + pub trace_summary: String, + pub memory_summary: String, +} + +struct PreparedTrace { + trace_id: String, + source_path: PathBuf, + payload: ApiMemoryTrace, +} + +/// Loads raw trace files, normalizes trace items, and builds memory summaries. +/// +/// The request/response wiring mirrors the memory trace summarize E2E flow: +/// `/v1/memories/trace_summarize` with one output object per input trace. +pub async fn build_memories_from_trace_files( + client: &ModelClient, + trace_paths: &[PathBuf], +) -> Result> { + if trace_paths.is_empty() { + return Ok(Vec::new()); + } + + let mut prepared = Vec::with_capacity(trace_paths.len()); + for (index, path) in trace_paths.iter().enumerate() { + prepared.push(prepare_trace(index + 1, path).await?); + } + + let traces = prepared.iter().map(|trace| trace.payload.clone()).collect(); + let output = client.summarize_memory_traces(traces).await?; + if output.len() != prepared.len() { + return Err(CodexErr::InvalidRequest(format!( + "unexpected memory summarize output length: expected {}, got {}", + prepared.len(), + output.len() + ))); + } + + Ok(prepared + .into_iter() + .zip(output) + .map(|(trace, summary)| BuiltTraceMemory { + trace_id: trace.trace_id, + source_path: trace.source_path, + trace_summary: summary.trace_summary, + memory_summary: summary.memory_summary, + }) + .collect()) +} + +async fn prepare_trace(index: usize, path: &Path) -> Result { + let text = load_trace_text(path).await?; + let items = load_trace_items(path, &text)?; + let trace_id = build_trace_id(index, path); + let source_path = path.to_path_buf(); + + Ok(PreparedTrace { + trace_id: trace_id.clone(), + source_path: source_path.clone(), + payload: ApiMemoryTrace { + id: trace_id, + metadata: ApiMemoryTraceMetadata { + source_path: source_path.display().to_string(), + }, + items, + }, + }) +} + +async fn load_trace_text(path: &Path) -> Result { + let raw = tokio::fs::read(path).await?; + Ok(decode_trace_bytes(&raw)) +} + +fn decode_trace_bytes(raw: &[u8]) -> String { + if let Some(without_bom) = raw.strip_prefix(&[0xEF, 0xBB, 0xBF]) + && let Ok(text) = String::from_utf8(without_bom.to_vec()) + { + return text; + } + if let Ok(text) = String::from_utf8(raw.to_vec()) { + return text; + } + raw.iter().map(|b| char::from(*b)).collect() +} + +fn load_trace_items(path: &Path, text: &str) -> Result> { + if let Ok(Value::Array(items)) = serde_json::from_str::(text) { + let dict_items = items + .into_iter() + .filter(serde_json::Value::is_object) + .collect::>(); + if dict_items.is_empty() { + return Err(CodexErr::InvalidRequest(format!( + "no object items found in trace file: {}", + path.display() + ))); + } + return normalize_trace_items(dict_items, path); + } + + let mut parsed_items = Vec::new(); + for line in text.lines() { + let line = line.trim(); + if line.is_empty() || (!line.starts_with('{') && !line.starts_with('[')) { + continue; + } + + let Ok(obj) = serde_json::from_str::(line) else { + continue; + }; + + match obj { + Value::Object(_) => parsed_items.push(obj), + Value::Array(inner) => { + parsed_items.extend(inner.into_iter().filter(serde_json::Value::is_object)) + } + _ => {} + } + } + + if parsed_items.is_empty() { + return Err(CodexErr::InvalidRequest(format!( + "no JSON items parsed from trace file: {}", + path.display() + ))); + } + + normalize_trace_items(parsed_items, path) +} + +fn normalize_trace_items(items: Vec, path: &Path) -> Result> { + let mut normalized = Vec::new(); + + for item in items { + let Value::Object(obj) = item else { + continue; + }; + + if let Some(payload) = obj.get("payload") { + if obj.get("type").and_then(Value::as_str) != Some("response_item") { + continue; + } + + match payload { + Value::Object(payload_item) => { + if is_allowed_trace_item(payload_item) { + normalized.push(Value::Object(payload_item.clone())); + } + } + Value::Array(payload_items) => { + for payload_item in payload_items { + if let Value::Object(payload_item) = payload_item + && is_allowed_trace_item(payload_item) + { + normalized.push(Value::Object(payload_item.clone())); + } + } + } + _ => {} + } + continue; + } + + if is_allowed_trace_item(&obj) { + normalized.push(Value::Object(obj)); + } + } + + if normalized.is_empty() { + return Err(CodexErr::InvalidRequest(format!( + "no valid trace items after normalization: {}", + path.display() + ))); + } + Ok(normalized) +} + +fn is_allowed_trace_item(item: &Map) -> bool { + let Some(item_type) = item.get("type").and_then(Value::as_str) else { + return false; + }; + + if item_type == "message" { + return matches!( + item.get("role").and_then(Value::as_str), + Some("assistant" | "system" | "developer" | "user") + ); + } + + true +} + +fn build_trace_id(index: usize, path: &Path) -> String { + let stem = path + .file_stem() + .map(|stem| stem.to_string_lossy().into_owned()) + .filter(|stem| !stem.is_empty()) + .unwrap_or_else(|| "trace".to_string()); + format!("trace_{index}_{stem}") +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use tempfile::tempdir; + + #[test] + fn normalize_trace_items_handles_payload_wrapper_and_message_role_filtering() { + let items = vec![ + serde_json::json!({ + "type": "response_item", + "payload": {"type": "message", "role": "assistant", "content": []} + }), + serde_json::json!({ + "type": "response_item", + "payload": [ + {"type": "message", "role": "user", "content": []}, + {"type": "message", "role": "tool", "content": []}, + {"type": "function_call", "name": "shell", "arguments": "{}", "call_id": "c1"} + ] + }), + serde_json::json!({ + "type": "not_response_item", + "payload": {"type": "message", "role": "assistant", "content": []} + }), + serde_json::json!({ + "type": "message", + "role": "developer", + "content": [] + }), + ]; + + let normalized = normalize_trace_items(items, Path::new("trace.json")).expect("normalize"); + let expected = vec![ + serde_json::json!({"type": "message", "role": "assistant", "content": []}), + serde_json::json!({"type": "message", "role": "user", "content": []}), + serde_json::json!({"type": "function_call", "name": "shell", "arguments": "{}", "call_id": "c1"}), + serde_json::json!({"type": "message", "role": "developer", "content": []}), + ]; + assert_eq!(normalized, expected); + } + + #[test] + fn load_trace_items_supports_jsonl_arrays_and_objects() { + let text = r#" +{"type":"response_item","payload":{"type":"message","role":"assistant","content":[]}} +[{"type":"message","role":"user","content":[]},{"type":"message","role":"tool","content":[]}] +"#; + let loaded = load_trace_items(Path::new("trace.jsonl"), text).expect("load"); + let expected = vec![ + serde_json::json!({"type":"message","role":"assistant","content":[]}), + serde_json::json!({"type":"message","role":"user","content":[]}), + ]; + assert_eq!(loaded, expected); + } + + #[tokio::test] + async fn load_trace_text_decodes_utf8_sig() { + let dir = tempdir().expect("tempdir"); + let path = dir.path().join("trace.json"); + tokio::fs::write( + &path, + [ + 0xEF, 0xBB, 0xBF, b'[', b'{', b'"', b't', b'y', b'p', b'e', b'"', b':', b'"', b'm', + b'e', b's', b's', b'a', b'g', b'e', b'"', b',', b'"', b'r', b'o', b'l', b'e', b'"', + b':', b'"', b'u', b's', b'e', b'r', b'"', b',', b'"', b'c', b'o', b'n', b't', b'e', + b'n', b't', b'"', b':', b'[', b']', b'}', b']', + ], + ) + .await + .expect("write"); + + let text = load_trace_text(&path).await.expect("decode"); + assert!(text.starts_with('[')); + } +}