From 2705f206a69a40596afccaafc41cd2dbbcb31b1c Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Mon, 2 Dec 2024 18:40:17 -0800 Subject: [PATCH] refactor: combine if statement + unwrap_or_else into one match --- sqlx-mysql/src/transaction.rs | 11 +++++++---- sqlx-postgres/src/transaction.rs | 11 +++++++---- sqlx-sqlite/src/connection/worker.rs | 15 ++++++++------- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 953735bf9a..11f56c0cb9 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -22,10 +22,13 @@ impl TransactionManager for MySqlTransactionManager { ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - if statement.is_some() && depth > 0 { - return Err(Error::InvalidSavePointStatement); - } - let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in a transaction + // (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; conn.execute(&*statement).await?; if !conn.in_transaction() { return Err(Error::BeginFailed); diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index ec01129d6f..f70961cc19 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -20,10 +20,13 @@ impl TransactionManager for PgTransactionManager { ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - if statement.is_some() && depth > 0 { - return Err(Error::InvalidSavePointStatement); - } - let statement = statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)); + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in + // a transaction (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; let rollback = Rollback::new(conn); rollback.conn.queue_simple_query(&statement)?; diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index ff908001aa..c8e6f0a268 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -184,17 +184,18 @@ impl ConnectionWorker { Command::Begin { tx, statement } => { let depth = conn.transaction_depth; - let statement = if depth == 0 { - statement.unwrap_or_else(|| begin_ansi_transaction_sql(depth)) - } else { - if statement.is_some() { + let statement = match statement { + // custom `BEGIN` statements are not allowed if + // we're already in a transaction (we need to + // issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => { if tx.blocking_send(Err(Error::InvalidSavePointStatement)).is_err() { break; } continue; - } - - begin_ansi_transaction_sql(depth) + }, + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), }; let res =