Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 99 additions & 8 deletions crates/goose/src/session/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ impl SessionStorage {
}

async fn import_legacy_session(pool: &Pool<Sqlite>, session: &Session) -> Result<()> {
let mut tx = pool.begin().await?;
let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?;

let recipe_json = match &session.recipe {
Some(recipe) => Some(serde_json::to_string(recipe)?),
Expand Down Expand Up @@ -724,7 +724,7 @@ impl SessionStorage {
}

async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
let mut tx = pool.begin().await?;
let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?;

let current_version = Self::get_schema_version(&mut tx).await?;

Expand Down Expand Up @@ -899,7 +899,7 @@ impl SessionStorage {
session_type: SessionType,
) -> Result<Session> {
let pool = self.pool().await?;
let mut tx = pool.begin().await?;
let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?;

let today = chrono::Utc::now().format("%Y%m%d").to_string();
let session = sqlx::query_as(
Expand Down Expand Up @@ -1071,7 +1071,7 @@ impl SessionStorage {
}

let pool = self.pool().await?;
let mut tx = pool.begin().await?;
let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?;
q = q.bind(&builder.session_id);
q.execute(&mut *tx).await?;

Expand Down Expand Up @@ -1116,7 +1116,7 @@ impl SessionStorage {

async fn add_message(&self, session_id: &str, message: &Message) -> Result<()> {
let pool = self.pool().await?;
let mut tx = pool.begin().await?;
let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?;

let metadata_json = serde_json::to_string(&message.metadata)?;

Expand Down Expand Up @@ -1154,7 +1154,7 @@ impl SessionStorage {
session_id: &str,
conversation: &Conversation,
) -> Result<()> {
let mut tx = pool.begin().await?;
let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?;

sqlx::query("DELETE FROM messages WHERE session_id = ?")
.bind(session_id)
Expand Down Expand Up @@ -1237,7 +1237,7 @@ impl SessionStorage {

async fn delete_session(&self, session_id: &str) -> Result<()> {
let pool = self.pool().await?;
let mut tx = pool.begin().await?;
let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?;

let exists =
sqlx::query_scalar::<_, bool>("SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?)")
Expand Down Expand Up @@ -1416,7 +1416,7 @@ impl SessionStorage {
) -> crate::conversation::message::MessageMetadata,
{
let pool = self.pool().await?;
let mut tx = pool.begin().await?;
let mut tx = pool.begin_with("BEGIN IMMEDIATE").await?;

let current_metadata_json = sqlx::query_scalar::<_, String>(
"SELECT metadata_json FROM messages WHERE message_id = ? AND session_id = ?",
Expand Down Expand Up @@ -1455,6 +1455,97 @@ mod tests {

const NUM_CONCURRENT_SESSIONS: i32 = 10;

async fn run_lock_upgrade_attempt(
pool: Pool<Sqlite>,
session_id: String,
begin_statement: &'static str,
worker_id: i32,
barrier: Option<Arc<tokio::sync::Barrier>>,
) -> anyhow::Result<()> {
let mut tx = pool.begin_with(begin_statement).await?;

sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM sessions WHERE id = ?")
.bind(&session_id)
.fetch_one(&mut *tx)
.await?;

if let Some(barrier) = barrier {
barrier.wait().await;
}

sqlx::query("UPDATE sessions SET total_tokens = ? WHERE id = ?")
.bind(worker_id)
.bind(&session_id)
.execute(&mut *tx)
.await?;

tx.commit().await?;
Ok(())
}

async fn run_lock_upgrade_race(
pool: Pool<Sqlite>,
session_id: String,
begin_statement: &'static str,
use_barrier: bool,
) -> Vec<anyhow::Result<()>> {
let barrier = if use_barrier {
Some(Arc::new(tokio::sync::Barrier::new(2)))
} else {
None
};
let mut handles = Vec::new();

for worker_id in 0..2 {
let pool = pool.clone();
let session_id = session_id.clone();
let barrier = barrier.clone();
handles.push(tokio::spawn(async move {
run_lock_upgrade_attempt(pool, session_id, begin_statement, worker_id, barrier)
.await
}));
}

let mut results = Vec::new();
for handle in handles {
results.push(handle.await.expect("lock-upgrade task panicked"));
}
results
}

#[tokio::test]
async fn test_begin_immediate_prevents_lock_upgrade_deadlock() {
let temp_dir = TempDir::new().unwrap();
let session_manager = SessionManager::new(temp_dir.path().to_path_buf());

let session = session_manager
.create_session(
PathBuf::from("/tmp/lock-upgrade-test"),
"Lock Upgrade Session".to_string(),
SessionType::User,
)
.await
.unwrap();

let pool = session_manager.storage().pool.clone();

let results = run_lock_upgrade_race(pool.clone(), session.id.clone(), "BEGIN", true).await;
assert!(
results.iter().any(Result::is_err),
"BEGIN (DEFERRED) should cause SQLITE_BUSY when two tasks try to upgrade SHARED → RESERVED"
);

let results = run_lock_upgrade_race(pool, session.id, "BEGIN IMMEDIATE", false).await;
assert!(
results.iter().all(Result::is_ok),
"BEGIN IMMEDIATE should serialize contention without SQLITE_BUSY: {:?}",
results
.iter()
.filter_map(|r| r.as_ref().err().map(ToString::to_string))
.collect::<Vec<_>>()
);
}

#[tokio::test]
async fn test_concurrent_session_creation() {
let temp_dir = TempDir::new().unwrap();
Expand Down