diff --git a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs
index 564911acbc1e..572080fa6ba5 100644
--- a/crates/goose/src/session/session_manager.rs
+++ b/crates/goose/src/session/session_manager.rs
@@ -462,7 +462,8 @@ impl SessionStorage {
         let options = SqliteConnectOptions::new()
             .filename(db_path)
             .create_if_missing(create_if_missing)
-            .busy_timeout(std::time::Duration::from_secs(5));
+            .busy_timeout(std::time::Duration::from_secs(5))
+            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
 
         sqlx::SqlitePool::connect_with(options).await.map_err(|e| {
             anyhow::anyhow!(
@@ -780,32 +781,39 @@ impl SessionStorage {
         name: String,
         session_type: SessionType,
     ) -> Result {
+        // Insert inside an explicit transaction that is committed before returning;
+        // executing directly on `&self.pool` contributed to the create-then-read race.
+        let mut tx = self.pool.begin().await?;
+
         let today = chrono::Utc::now().format("%Y%m%d").to_string();
-        Ok(sqlx::query_as(
-            r#"
-            INSERT INTO sessions (id, name, user_set_name, session_type, working_dir, extension_data)
-            VALUES (
-                ? || '_' || CAST(COALESCE((
-                    SELECT MAX(CAST(SUBSTR(id, 10) AS INTEGER))
-                    FROM sessions
-                    WHERE id LIKE ? || '_%'
-                ), 0) + 1 AS TEXT),
-                ?,
-                FALSE,
-                ?,
-                ?,
-                '{}'
-            )
-            RETURNING *
-        "#,
-        )
-        .bind(&today)
-        .bind(&today)
-        .bind(&name)
-        .bind(session_type.to_string())
-        .bind(working_dir.to_string_lossy().as_ref())
-        .fetch_one(&self.pool)
-        .await?)
+        let session = sqlx::query_as(
+            r#"
+            INSERT INTO sessions (id, name, user_set_name, session_type, working_dir, extension_data)
+            VALUES (
+                ? || '_' || CAST(COALESCE((
+                    SELECT MAX(CAST(SUBSTR(id, 10) AS INTEGER))
+                    FROM sessions
+                    WHERE id LIKE ? || '_%'
+                ), 0) + 1 AS TEXT),
+                ?,
+                FALSE,
+                ?,
+                ?,
+                '{}'
+            )
+            RETURNING *
+        "#,
+        )
+        .bind(&today)
+        .bind(&today)
+        .bind(&name)
+        .bind(session_type.to_string())
+        .bind(working_dir.to_string_lossy().as_ref())
+        .fetch_one(&mut *tx)
+        .await?;
+
+        tx.commit().await?;
+        Ok(session)
     }
 
     async fn get_session(&self, id: &str, include_messages: bool) -> Result {
@@ -1364,4 +1372,156 @@ mod tests {
         assert!(imported.user_set_name);
         assert_eq!(imported.working_dir, PathBuf::from("/tmp/test"));
     }
+
+    /// Regression test for the create-then-read race under WAL mode.
+    ///
+    /// Each task follows the flow described for `build_session()`:
+    /// 1. Call `create_session()`, as happens when no session id exists yet.
+    /// 2. Take the id of the returned session.
+    /// 3. Immediately call `get_session()` with that id (like `CliSession::new` does).
+    ///
+    /// Every task must be able to read back the session it just created.
+    #[tokio::test]
+    async fn test_wal_race_condition_create_then_get() {
+        const NUM_TASKS: usize = 100;
+        let mut handles = vec![];
+
+        for i in 0..NUM_TASKS {
+            let handle = tokio::spawn(async move {
+                // Create the session (like builder.rs).
+                let session_id = SessionManager::create_session(
+                    PathBuf::from(format!("/tmp/test_{}", i)),
+                    format!("Race test session {}", i),
+                    SessionType::User,
+                )
+                .await
+                .expect("Failed to create session")
+                .id;
+
+                // The critical read: CliSession::new loads the conversation from the
+                // just-created session, hence include_messages = true.
+                let fetched = SessionManager::get_session(&session_id, true).await;
+
+                match fetched {
+                    Ok(fetched_session) => {
+                        assert_eq!(
+                            fetched_session.id, session_id,
+                            "Session ID mismatch for session {}",
+                            i
+                        );
+                        println!(
+                            "✅ SUCCESS: Session {} found immediately after creation",
+                            session_id
+                        );
+                        Ok(session_id)
+                    }
+                    Err(e) => {
+                        // This is the race condition we're testing for.
+                        eprintln!(
+                            "⚠️ RACE DETECTED: Session {} not found immediately after creation: {}",
+                            session_id, e
+                        );
+                        Err(format!("Session {} not found: {}", session_id, e))
+                    }
+                }
+            });
+
+            handles.push(handle);
+        }
+
+        // Collect the failures reported by the tasks.
+        let mut errors = vec![];
+        for handle in handles {
+            if let Err(e) = handle.await.expect("spawned task panicked") {
+                errors.push(e);
+            }
+        }
+
+        // Report any race conditions detected.
+        if !errors.is_empty() {
+            panic!(
+                "WAL race condition detected in {} out of {} tasks:\n{}",
+                errors.len(),
+                NUM_TASKS,
+                errors.join("\n")
+            );
+        }
+    }
+
+    /// Exercises the same create-then-read sequence with the blocking pattern
+    /// attributed to `CliSession::new`: the read goes through `block_in_place` and
+    /// `Handle::block_on` on a multi-threaded runtime.
+    ///
+    /// Runs against its own database file in a temporary directory.
+    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+    async fn test_wal_race_with_blocking_pattern() {
+        let temp_dir = TempDir::new().unwrap();
+        let db_path = temp_dir.path().join("test_blocking_race.db");
+        let storage = Arc::new(SessionStorage::create(&db_path).await.unwrap());
+
+        const NUM_ITERATIONS: usize = 100;
+        let mut handles = vec![];
+
+        for i in 0..NUM_ITERATIONS {
+            let storage = Arc::clone(&storage);
+
+            let handle = tokio::spawn(async move {
+                // Create a session.
+                let created = storage
+                    .create_session(
+                        PathBuf::from(format!("/tmp/test_{}", i)),
+                        format!("Blocking test {}", i),
+                        SessionType::User,
+                    )
+                    .await
+                    .unwrap();
+
+                // Simulate CliSession::new's blocking pattern.
+                let session_id = created.id;
+                let fetched = tokio::task::block_in_place(|| {
+                    tokio::runtime::Handle::current()
+                        .block_on(async { storage.get_session(&session_id, false).await })
+                });
+
+                match fetched {
+                    Ok(_) => {
+                        println!(
+                            "✅ SUCCESS (blocking): Session {} found immediately after creation",
+                            session_id
+                        );
+                        Ok(session_id)
+                    }
+                    Err(e) => {
+                        eprintln!(
+                            "⚠️ RACE DETECTED with block_in_place: Session {} not found: {}",
+                            session_id, e
+                        );
+                        Err(format!(
+                            "Session {} not found with blocking: {}",
+                            session_id, e
+                        ))
+                    }
+                }
+            });
+
+            handles.push(handle);
+        }
+
+        // Collect the failures reported by the tasks.
+        let mut errors = vec![];
+        for handle in handles {
+            if let Err(e) = handle.await.expect("spawned task panicked") {
+                errors.push(e);
+            }
+        }
+
+        if !errors.is_empty() {
+            panic!(
+                "WAL race condition detected with blocking pattern in {} out of {} iterations:\n{}",
+                errors.len(),
+                NUM_ITERATIONS,
+                errors.join("\n")
+            );
+        }
+    }
 }