diff --git a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs index 89955c6d546e..a11e4b12735e 100644 --- a/crates/goose/src/session/session_manager.rs +++ b/crates/goose/src/session/session_manager.rs @@ -665,7 +665,7 @@ impl SessionStorage { } async fn import_legacy_session(pool: &Pool, session: &Session) -> Result<()> { - let mut tx = pool.begin().await?; + let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?; let recipe_json = match &session.recipe { Some(recipe) => Some(serde_json::to_string(recipe)?), @@ -724,7 +724,7 @@ impl SessionStorage { } async fn run_migrations(pool: &Pool) -> Result<()> { - let mut tx = pool.begin().await?; + let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?; let current_version = Self::get_schema_version(&mut tx).await?; @@ -899,7 +899,7 @@ impl SessionStorage { session_type: SessionType, ) -> Result { let pool = self.pool().await?; - let mut tx = pool.begin().await?; + let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?; let today = chrono::Utc::now().format("%Y%m%d").to_string(); let session = sqlx::query_as( @@ -1071,7 +1071,7 @@ impl SessionStorage { } let pool = self.pool().await?; - let mut tx = pool.begin().await?; + let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?; q = q.bind(&builder.session_id); q.execute(&mut *tx).await?; @@ -1116,7 +1116,7 @@ impl SessionStorage { async fn add_message(&self, session_id: &str, message: &Message) -> Result<()> { let pool = self.pool().await?; - let mut tx = pool.begin().await?; + let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?; let metadata_json = serde_json::to_string(&message.metadata)?; @@ -1154,7 +1154,7 @@ impl SessionStorage { session_id: &str, conversation: &Conversation, ) -> Result<()> { - let mut tx = pool.begin().await?; + let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?; sqlx::query("DELETE FROM messages WHERE session_id = ?") .bind(session_id) @@ -1237,7 +1237,7 @@ impl SessionStorage { async fn delete_session(&self, session_id: &str) -> Result<()> { let pool = self.pool().await?; - let mut tx = pool.begin().await?; + let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?; let exists = sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?)") @@ -1416,7 +1416,7 @@ impl SessionStorage { ) -> crate::conversation::message::MessageMetadata, { let pool = self.pool().await?; - let mut tx = pool.begin().await?; + let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?; let current_metadata_json = sqlx::query_scalar::<_, String>( "SELECT metadata_json FROM messages WHERE message_id = ? AND session_id = ?", @@ -1455,6 +1455,97 @@ mod tests { const NUM_CONCURRENT_SESSIONS: i32 = 10; + async fn run_lock_upgrade_attempt( + pool: Pool, + session_id: String, + begin_statement: &'static str, + worker_id: i32, + barrier: Option>, + ) -> anyhow::Result<()> { + let mut tx = pool.begin_with(begin_statement).await?; + + sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM sessions WHERE id = ?") + .bind(&session_id) + .fetch_one(&mut *tx) + .await?; + + if let Some(barrier) = barrier { + barrier.wait().await; + } + + sqlx::query("UPDATE sessions SET total_tokens = ? WHERE id = ?") + .bind(worker_id) + .bind(&session_id) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + Ok(()) + } + + async fn run_lock_upgrade_race( + pool: Pool, + session_id: String, + begin_statement: &'static str, + use_barrier: bool, + ) -> Vec> { + let barrier = if use_barrier { + Some(Arc::new(tokio::sync::Barrier::new(2))) + } else { + None + }; + let mut handles = Vec::new(); + + for worker_id in 0..2 { + let pool = pool.clone(); + let session_id = session_id.clone(); + let barrier = barrier.clone(); + handles.push(tokio::spawn(async move { + run_lock_upgrade_attempt(pool, session_id, begin_statement, worker_id, barrier) + .await + })); + } + + let mut results = Vec::new(); + for handle in handles { + results.push(handle.await.expect("lock-upgrade task panicked")); + } + results + } + + #[tokio::test] + async fn test_begin_immediate_prevents_lock_upgrade_deadlock() { + let temp_dir = TempDir::new().unwrap(); + let session_manager = SessionManager::new(temp_dir.path().to_path_buf()); + + let session = session_manager + .create_session( + PathBuf::from("/tmp/lock-upgrade-test"), + "Lock Upgrade Session".to_string(), + SessionType::User, + ) + .await + .unwrap(); + + let pool = session_manager.storage().pool.clone(); + + let results = run_lock_upgrade_race(pool.clone(), session.id.clone(), "BEGIN", true).await; + assert!( + results.iter().any(Result::is_err), + "BEGIN (DEFERRED) should cause SQLITE_BUSY when two tasks try to upgrade SHARED → RESERVED" + ); + + let results = run_lock_upgrade_race(pool, session.id, "BEGIN IMMEDIATE", false).await; + assert!( + results.iter().all(Result::is_ok), + "BEGIN IMMEDIATE should serialize contention without SQLITE_BUSY: {:?}", + results + .iter() + .filter_map(|r| r.as_ref().err().map(ToString::to_string)) + .collect::>() + ); + } + #[tokio::test] async fn test_concurrent_session_creation() { let temp_dir = TempDir::new().unwrap();