diff --git a/codex-rs/core/src/rollout/list.rs b/codex-rs/core/src/rollout/list.rs index 16bd5b16edd..62c5e44805e 100644 --- a/codex-rs/core/src/rollout/list.rs +++ b/codex-rs/core/src/rollout/list.rs @@ -261,6 +261,14 @@ impl<'de> serde::Deserialize<'de> for Cursor { } } +impl From for Cursor { + fn from(anchor: codex_state::Anchor) -> Self { + let ts = OffsetDateTime::from_unix_timestamp(anchor.ts.timestamp()) + .unwrap_or(OffsetDateTime::UNIX_EPOCH); + Self::new(ts, anchor.id) + } +} + /// Retrieve recorded thread file paths with token pagination. The returned `next_cursor` /// can be supplied on the next call to resume after the last returned item, resilient to /// concurrent new sessions being appended. Ordering is stable by the requested sort key @@ -989,7 +997,6 @@ async fn read_head_summary(path: &Path, head_limit: usize) -> io::Result, default_provider: &str, ) -> std::io::Result { - let stage = "list_threads"; - let page = get_threads( + Self::list_threads_with_db_fallback( codex_home, page_size, cursor, @@ -129,39 +131,36 @@ impl RolloutRecorder { allowed_sources, model_providers, default_provider, + false, ) - .await?; + .await + } - // TODO(jif): drop after sqlite migration phase 1 - let state_db_ctx = state_db::open_if_present(codex_home, default_provider).await; - if let Some(db_ids) = state_db::list_thread_ids_db( - state_db_ctx.as_deref(), + /// List archived threads (rollout files) under the archived sessions directory. + pub async fn list_archived_threads( + codex_home: &Path, + page_size: usize, + cursor: Option<&Cursor>, + sort_key: ThreadSortKey, + allowed_sources: &[SessionSource], + model_providers: Option<&[String]>, + default_provider: &str, + ) -> std::io::Result { + Self::list_threads_with_db_fallback( codex_home, page_size, cursor, sort_key, allowed_sources, model_providers, - false, - stage, + default_provider, + true, ) .await - { - if page.items.len() != db_ids.len() { - state_db::record_discrepancy(stage, "bad_len"); - return Ok(page); - } - for (id, item) in db_ids.iter().zip(page.items.iter()) { - if !item.path.display().to_string().contains(&id.to_string()) { - state_db::record_discrepancy(stage, "bad_id"); - } - } - } - Ok(page) } - /// List archived threads (rollout files) under the archived sessions directory. - pub async fn list_archived_threads( + #[allow(clippy::too_many_arguments)] + async fn list_threads_with_db_fallback( codex_home: &Path, page_size: usize, cursor: Option<&Cursor>, @@ -169,49 +168,53 @@ impl RolloutRecorder { allowed_sources: &[SessionSource], model_providers: Option<&[String]>, default_provider: &str, + archived: bool, ) -> std::io::Result { - let stage = "list_archived_threads"; - let root = codex_home.join(ARCHIVED_SESSIONS_SUBDIR); - let page = get_threads_in_root( - root, + let state_db_ctx = state_db::open_if_present(codex_home, default_provider).await; + if let Some(db_page) = state_db::list_threads_db( + state_db_ctx.as_deref(), + codex_home, page_size, cursor, sort_key, - ThreadListConfig { - allowed_sources, - model_providers, - default_provider, - layout: ThreadListLayout::Flat, - }, + allowed_sources, + model_providers, + archived, ) - .await?; + .await + { + let mut page: ThreadsPage = db_page.into(); + populate_thread_heads(page.items.as_mut_slice()).await; + return Ok(page); + } - // TODO(jif): drop after sqlite migration phase 1 - let state_db_ctx = state_db::open_if_present(codex_home, default_provider).await; - if let Some(db_ids) = state_db::list_thread_ids_db( - state_db_ctx.as_deref(), + if archived { + let root = codex_home.join(ARCHIVED_SESSIONS_SUBDIR); + return get_threads_in_root( + root, + page_size, + cursor, + sort_key, + ThreadListConfig { + allowed_sources, + model_providers, + default_provider, + layout: ThreadListLayout::Flat, + }, + ) + .await; + } + + get_threads( codex_home, page_size, cursor, sort_key, allowed_sources, model_providers, - true, - stage, + default_provider, ) .await - { - if page.items.len() != db_ids.len() { - state_db::record_discrepancy(stage, "bad_len"); - return Ok(page); - } - for (id, item) in db_ids.iter().zip(page.items.iter()) { - if !item.path.display().to_string().contains(&id.to_string()) { - state_db::record_discrepancy(stage, "bad_id"); - } - } - } - Ok(page) } /// Find the newest recorded thread path, optionally filtering to a matching cwd. @@ -645,6 +648,41 @@ impl JsonlWriter { } } +impl From for ThreadsPage { + fn from(db_page: codex_state::ThreadsPage) -> Self { + let items = db_page + .items + .into_iter() + .map(|item| ThreadItem { + path: item.rollout_path, + head: Vec::new(), + created_at: Some(item.created_at.to_rfc3339_opts(SecondsFormat::Secs, true)), + updated_at: Some(item.updated_at.to_rfc3339_opts(SecondsFormat::Secs, true)), + }) + .collect(); + Self { + items, + next_cursor: db_page.next_anchor.map(Into::into), + num_scanned_files: db_page.num_scanned_rows, + reached_scan_cap: false, + } + } +} + +async fn populate_thread_heads(items: &mut [ThreadItem]) { + for item in items { + item.head = read_head_for_summary(item.path.as_path()) + .await + .unwrap_or_else(|err| { + warn!( + "failed to read rollout head from state db path: {} ({err})", + item.path.display() + ); + Vec::new() + }); + } +} + fn select_resume_path(page: &ThreadsPage, filter_cwd: Option<&Path>) -> Option { match filter_cwd { Some(cwd) => page.items.iter().find_map(|item| { diff --git a/codex-rs/core/src/rollout/tests.rs b/codex-rs/core/src/rollout/tests.rs index ee750b126d2..e42e1861576 100644 --- a/codex-rs/core/src/rollout/tests.rs +++ b/codex-rs/core/src/rollout/tests.rs @@ -7,6 +7,7 @@ use std::fs::{self}; use std::io::Write; use std::path::Path; +use chrono::TimeZone; use pretty_assertions::assert_eq; use tempfile::TempDir; use time::Duration; @@ -22,6 +23,7 @@ use crate::rollout::list::ThreadItem; use crate::rollout::list::ThreadSortKey; use crate::rollout::list::ThreadsPage; use crate::rollout::list::get_threads; +use crate::rollout::recorder::RolloutRecorder; use crate::rollout::rollout_date_parts; use anyhow::Result; use codex_protocol::ThreadId; @@ -45,6 +47,191 @@ fn provider_vec(providers: &[&str]) -> Vec { .collect() } +async fn insert_state_db_thread( + home: &Path, + thread_id: ThreadId, + rollout_path: &Path, + archived: bool, +) { + let runtime = + codex_state::StateRuntime::init(home.to_path_buf(), TEST_PROVIDER.to_string(), None) + .await + .expect("state db should initialize"); + let created_at = chrono::Utc + .with_ymd_and_hms(2025, 1, 3, 12, 0, 0) + .single() + .expect("valid datetime"); + let mut builder = codex_state::ThreadMetadataBuilder::new( + thread_id, + rollout_path.to_path_buf(), + created_at, + SessionSource::Cli, + ); + builder.model_provider = Some(TEST_PROVIDER.to_string()); + builder.cwd = home.to_path_buf(); + if archived { + builder.archived_at = Some(created_at); + } + let mut metadata = builder.build(TEST_PROVIDER); + metadata.has_user_event = true; + runtime + .upsert_thread(&metadata) + .await + .expect("state db upsert should succeed"); +} + +#[tokio::test] +async fn list_threads_prefers_state_db_when_available() { + let temp = TempDir::new().unwrap(); + let home = temp.path(); + let fs_uuid = Uuid::from_u128(101); + write_session_file( + home, + "2025-01-03T13-00-00", + fs_uuid, + 1, + Some(SessionSource::Cli), + ) + .unwrap(); + + let db_uuid = Uuid::from_u128(102); + let db_thread_id = ThreadId::from_string(&db_uuid.to_string()).expect("valid thread id"); + let db_rollout_path = home.join(format!( + "sessions/2025/01/03/rollout-2025-01-03T12-00-00-{db_uuid}.jsonl" + )); + insert_state_db_thread(home, db_thread_id, db_rollout_path.as_path(), false).await; + + let page = RolloutRecorder::list_threads( + home, + 10, + None, + ThreadSortKey::CreatedAt, + NO_SOURCE_FILTER, + None, + TEST_PROVIDER, + ) + .await + .expect("thread listing should succeed"); + + assert_eq!(page.items.len(), 1); + assert_eq!(page.items[0].path, db_rollout_path); +} + +#[tokio::test] +async fn list_archived_threads_prefers_state_db_when_available() { + let temp = TempDir::new().unwrap(); + let home = temp.path(); + let archived_root = home.join("archived_sessions"); + fs::create_dir_all(&archived_root).unwrap(); + let fs_uuid = Uuid::from_u128(201); + let fs_path = archived_root.join(format!("rollout-2025-01-03T13-00-00-{fs_uuid}.jsonl")); + fs::write(&fs_path, "{\"type\":\"session_meta\",\"payload\":{}}\n").unwrap(); + + let db_uuid = Uuid::from_u128(202); + let db_thread_id = ThreadId::from_string(&db_uuid.to_string()).expect("valid thread id"); + let db_rollout_path = + archived_root.join(format!("rollout-2025-01-03T12-00-00-{db_uuid}.jsonl")); + insert_state_db_thread(home, db_thread_id, db_rollout_path.as_path(), true).await; + + let page = RolloutRecorder::list_archived_threads( + home, + 10, + None, + ThreadSortKey::CreatedAt, + NO_SOURCE_FILTER, + None, + TEST_PROVIDER, + ) + .await + .expect("archived thread listing should succeed"); + + assert_eq!(page.items.len(), 1); + assert_eq!(page.items[0].path, db_rollout_path); +} + +#[tokio::test] +async fn list_threads_db_excludes_archived_entries() { + let temp = TempDir::new().unwrap(); + let home = temp.path(); + let sessions_root = home.join("sessions/2025/01/03"); + let archived_root = home.join("archived_sessions"); + fs::create_dir_all(&sessions_root).unwrap(); + fs::create_dir_all(&archived_root).unwrap(); + + let active_uuid = Uuid::from_u128(211); + let active_thread_id = + ThreadId::from_string(&active_uuid.to_string()).expect("valid active thread id"); + let active_rollout_path = + sessions_root.join(format!("rollout-2025-01-03T12-00-00-{active_uuid}.jsonl")); + insert_state_db_thread(home, active_thread_id, active_rollout_path.as_path(), false).await; + + let archived_uuid = Uuid::from_u128(212); + let archived_thread_id = + ThreadId::from_string(&archived_uuid.to_string()).expect("valid archived thread id"); + let archived_rollout_path = + archived_root.join(format!("rollout-2025-01-03T11-00-00-{archived_uuid}.jsonl")); + insert_state_db_thread( + home, + archived_thread_id, + archived_rollout_path.as_path(), + true, + ) + .await; + + let page = RolloutRecorder::list_threads( + home, + 10, + None, + ThreadSortKey::CreatedAt, + NO_SOURCE_FILTER, + None, + TEST_PROVIDER, + ) + .await + .expect("thread listing should succeed"); + + assert_eq!(page.items.len(), 1); + assert_eq!(page.items[0].path, active_rollout_path); +} + +#[tokio::test] +async fn list_threads_falls_back_to_files_when_state_db_is_unavailable() { + let temp = TempDir::new().unwrap(); + let home = temp.path(); + let fs_uuid = Uuid::from_u128(301); + write_session_file( + home, + "2025-01-03T13-00-00", + fs_uuid, + 1, + Some(SessionSource::Cli), + ) + .unwrap(); + + let page = RolloutRecorder::list_threads( + home, + 10, + None, + ThreadSortKey::CreatedAt, + NO_SOURCE_FILTER, + None, + TEST_PROVIDER, + ) + .await + .expect("thread listing should succeed"); + + assert_eq!(page.items.len(), 1); + let file_name = page.items[0] + .path + .file_name() + .and_then(|value| value.to_str()) + .expect("rollout file name should be utf8"); + assert!( + file_name.contains(&fs_uuid.to_string()), + "expected file path from filesystem listing, got: {file_name}" + ); +} + #[test] fn rollout_date_parts_extracts_directory_components() { let file_name = OsStr::new("rollout-2025-03-01T09-00-00-123.jsonl"); diff --git a/codex-rs/core/src/state_db.rs b/codex-rs/core/src/state_db.rs index ff95ed946f9..7a0894a327c 100644 --- a/codex-rs/core/src/state_db.rs +++ b/codex-rs/core/src/state_db.rs @@ -181,6 +181,59 @@ pub async fn list_thread_ids_db( } } +/// List thread metadata from SQLite without rollout directory traversal. +#[allow(clippy::too_many_arguments)] +pub async fn list_threads_db( + context: Option<&codex_state::StateRuntime>, + codex_home: &Path, + page_size: usize, + cursor: Option<&Cursor>, + sort_key: ThreadSortKey, + allowed_sources: &[SessionSource], + model_providers: Option<&[String]>, + archived: bool, +) -> Option { + let ctx = context?; + if ctx.codex_home() != codex_home { + warn!( + "state db codex_home mismatch: expected {}, got {}", + ctx.codex_home().display(), + codex_home.display() + ); + } + + let anchor = cursor_to_anchor(cursor); + let allowed_sources: Vec = allowed_sources + .iter() + .map(|value| match serde_json::to_value(value) { + Ok(Value::String(s)) => s, + Ok(other) => other.to_string(), + Err(_) => String::new(), + }) + .collect(); + let model_providers = model_providers.map(<[String]>::to_vec); + match ctx + .list_threads( + page_size, + anchor.as_ref(), + match sort_key { + ThreadSortKey::CreatedAt => codex_state::SortKey::CreatedAt, + ThreadSortKey::UpdatedAt => codex_state::SortKey::UpdatedAt, + }, + allowed_sources.as_slice(), + model_providers.as_deref(), + archived, + ) + .await + { + Ok(page) => Some(page), + Err(err) => { + warn!("state db list_threads failed: {err}"); + None + } + } +} + /// Look up the rollout path for a thread id using SQLite. pub async fn find_rollout_path_by_id( context: Option<&codex_state::StateRuntime>, diff --git a/codex-rs/state/src/extract.rs b/codex-rs/state/src/extract.rs index ad40118086b..bfcd75416f0 100644 --- a/codex-rs/state/src/extract.rs +++ b/codex-rs/state/src/extract.rs @@ -3,14 +3,19 @@ use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; use codex_protocol::models::is_local_image_close_tag_text; use codex_protocol::models::is_local_image_open_tag_text; +use codex_protocol::protocol::ENVIRONMENT_CONTEXT_OPEN_TAG; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::RolloutItem; use codex_protocol::protocol::SessionMetaLine; use codex_protocol::protocol::TurnContextItem; +use codex_protocol::protocol::USER_INSTRUCTIONS_OPEN_TAG; use codex_protocol::protocol::USER_MESSAGE_BEGIN; use serde::Serialize; use serde_json::Value; +const USER_INSTRUCTIONS_PREFIX: &str = "# AGENTS.md instructions for "; +const TURN_ABORTED_OPEN_TAG: &str = ""; + /// Apply a rollout item to the metadata structure. pub fn apply_rollout_item( metadata: &mut ThreadMetadata, @@ -74,14 +79,48 @@ fn apply_event_msg(metadata: &mut ThreadMetadata, event: &EventMsg) { } fn apply_response_item(metadata: &mut ThreadMetadata, item: &ResponseItem) { - if let Some(text) = extract_user_message_text(item) { - metadata.has_user_event = true; - if metadata.title.is_empty() { - metadata.title = text; - } + if !is_user_response_item(item) { + return; + } + metadata.has_user_event = true; + if metadata.title.is_empty() + && let Some(text) = extract_user_message_text(item) + { + metadata.title = text; } } +// TODO(jif) unify once the discussion is settled +fn is_user_response_item(item: &ResponseItem) -> bool { + let ResponseItem::Message { role, content, .. } = item else { + return false; + }; + role == "user" + && !is_user_instructions(content.as_slice()) + && !is_session_prefix_content(content.as_slice()) +} + +fn is_user_instructions(content: &[ContentItem]) -> bool { + if let [ContentItem::InputText { text }] = content { + text.starts_with(USER_INSTRUCTIONS_PREFIX) || text.starts_with(USER_INSTRUCTIONS_OPEN_TAG) + } else { + false + } +} + +fn is_session_prefix_content(content: &[ContentItem]) -> bool { + if let [ContentItem::InputText { text }] = content { + is_session_prefix(text) + } else { + false + } +} + +fn is_session_prefix(text: &str) -> bool { + let lowered = text.trim_start().to_ascii_lowercase(); + lowered.starts_with(ENVIRONMENT_CONTEXT_OPEN_TAG) || lowered.starts_with(TURN_ABORTED_OPEN_TAG) +} + fn extract_user_message_text(item: &ResponseItem) -> Option { let ResponseItem::Message { role, content, .. } = item else { return None; @@ -125,6 +164,7 @@ pub(crate) fn enum_to_string(value: &T) -> String { #[cfg(test)] mod tests { + use super::apply_rollout_item; use super::extract_user_message_text; use crate::model::ThreadMetadata; use chrono::DateTime; @@ -132,6 +172,8 @@ mod tests { use codex_protocol::ThreadId; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; + use codex_protocol::protocol::RolloutItem; + use codex_protocol::protocol::USER_INSTRUCTIONS_OPEN_TAG; use codex_protocol::protocol::USER_MESSAGE_BEGIN; use pretty_assertions::assert_eq; use std::path::PathBuf; @@ -158,10 +200,69 @@ mod tests { } #[test] - fn diff_fields_detects_changes() { - let id = ThreadId::from_string(&Uuid::now_v7().to_string()).expect("thread id"); + fn user_instructions_do_not_count_as_user_events() { + let mut metadata = metadata_for_test(); + let item = RolloutItem::ResponseItem(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: format!( + "# AGENTS.md instructions for /tmp\n\n{USER_INSTRUCTIONS_OPEN_TAG}test" + ), + }], + end_turn: None, + phase: None, + }); + + apply_rollout_item(&mut metadata, &item, "test-provider"); + + assert_eq!(metadata.has_user_event, false); + assert_eq!(metadata.title, ""); + } + + #[test] + fn session_prefix_messages_do_not_count_as_user_events() { + let mut metadata = metadata_for_test(); + let item = RolloutItem::ResponseItem(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "\n {\"cwd\":\"/tmp\"}" + .to_string(), + }], + end_turn: None, + phase: None, + }); + + apply_rollout_item(&mut metadata, &item, "test-provider"); + + assert_eq!(metadata.has_user_event, false); + assert_eq!(metadata.title, ""); + } + + #[test] + fn image_only_user_messages_still_count_as_user_events() { + let mut metadata = metadata_for_test(); + let item = RolloutItem::ResponseItem(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputImage { + image_url: "https://example.com/image.png".to_string(), + }], + end_turn: None, + phase: None, + }); + + apply_rollout_item(&mut metadata, &item, "test-provider"); + + assert_eq!(metadata.has_user_event, true); + assert_eq!(metadata.title, ""); + } + + fn metadata_for_test() -> ThreadMetadata { + let id = ThreadId::from_string(&Uuid::from_u128(42).to_string()).expect("thread id"); let created_at = DateTime::::from_timestamp(1_735_689_600, 0).expect("timestamp"); - let base = ThreadMetadata { + ThreadMetadata { id, rollout_path: PathBuf::from("/tmp/a.jsonl"), created_at, @@ -169,7 +270,7 @@ mod tests { source: "cli".to_string(), model_provider: "openai".to_string(), cwd: PathBuf::from("/tmp"), - title: "hello".to_string(), + title: String::new(), sandbox_policy: "read-only".to_string(), approval_mode: "on-request".to_string(), tokens_used: 1, @@ -178,7 +279,14 @@ mod tests { git_sha: None, git_branch: None, git_origin_url: None, - }; + } + } + + #[test] + fn diff_fields_detects_changes() { + let mut base = metadata_for_test(); + base.id = ThreadId::from_string(&Uuid::now_v7().to_string()).expect("thread id"); + base.title = "hello".to_string(); let mut other = base.clone(); other.tokens_used = 2; other.title = "world".to_string();