From 472b7c281265fa7bd32201ffc987b8286a8e4576 Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Tue, 21 Sep 2021 13:51:46 +0200 Subject: [PATCH 01/12] Stream implementation --- src/database/connection.rs | 15 ++++ src/database/db_connection.rs | 3 + src/database/db_transaction.rs | 5 ++ src/database/mod.rs | 4 ++ src/database/stream.rs | 121 +++++++++++++++++++++++++++++++++ src/driver/sqlx_mysql.rs | 14 +++- src/driver/sqlx_postgres.rs | 14 +++- src/driver/sqlx_sqlite.rs | 14 +++- src/executor/insert.rs | 6 +- src/executor/paginator.rs | 4 +- src/executor/select.rs | 41 +++++++++++ tests/stream_tests.rs | 39 +++++++++++ 12 files changed, 272 insertions(+), 8 deletions(-) create mode 100644 src/database/stream.rs create mode 100644 tests/stream_tests.rs diff --git a/src/database/connection.rs b/src/database/connection.rs index d2a6f9018..df59d2951 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -110,6 +110,21 @@ impl ConnectionTrait for DatabaseConnection { } } + #[cfg(feature = "sqlx-dep")] + async fn stream(&self, stmt: Statement) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(_) => panic!("Mock"),//TODO: can it be permitted? How? + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&self, _callback: F) -> Result> diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 615c0c3fd..02476e439 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -11,6 +11,9 @@ pub trait ConnectionTrait: Sync { async fn query_all(&self, stmt: Statement) -> Result, DbErr>; + #[cfg(feature = "sqlx-dep")] + async fn stream(&self, stmt: Statement) -> Result; + /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&self, callback: F) -> Result> diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs index 000f6f609..69f2b68b2 100644 --- a/src/database/db_transaction.rs +++ b/src/database/db_transaction.rs @@ -222,6 +222,11 @@ impl<'a> ConnectionTrait for DatabaseTransaction<'a> { _res.map_err(sqlx_error_to_query_err) } + #[cfg(feature = "sqlx-dep")] + async fn stream(&self, _stmt: Statement) -> Result { + todo!(); + } + /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&self, _callback: F) -> Result> diff --git a/src/database/mod.rs b/src/database/mod.rs index 296b4d0bb..464b76563 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -5,6 +5,8 @@ mod statement; mod transaction; mod db_connection; mod db_transaction; +#[cfg(feature = "sqlx-dep")] +mod stream; pub use connection::*; #[cfg(feature = "mock")] @@ -13,6 +15,8 @@ pub use statement::*; pub use transaction::*; pub use db_connection::*; pub use db_transaction::*; +#[cfg(feature = "sqlx-dep")] +pub use stream::*; use crate::DbErr; diff --git a/src/database/stream.rs b/src/database/stream.rs new file mode 100644 index 000000000..a9112c1cd --- /dev/null +++ b/src/database/stream.rs @@ -0,0 +1,121 @@ +use std::{pin::Pin, sync::Arc, task::Poll}; + +use futures::{Stream, TryStreamExt}; + +use sqlx::{pool::PoolConnection, Executor}; + +use crate::{sqlx_error_to_query_err, DbErr, QueryResult, Statement}; + +enum Connection { + #[cfg(feature = "sqlx-mysql")] + MySql(PoolConnection), + #[cfg(feature = "sqlx-postgres")] + Postgres(PoolConnection), + #[cfg(feature = "sqlx-sqlite")] + Sqlite(PoolConnection), +} + +pub struct QueryStream { + stmt: Arc, + conn: Arc, + stream: Option>>>>, +} + +#[cfg(feature = "sqlx-mysql")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream { + stmt: Arc::new(stmt), + conn: Arc::new(Connection::MySql(conn)), + stream: None + } + } +} + +#[cfg(feature = "sqlx-postgres")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream { + stmt: Arc::new(stmt), + conn: Arc::new(Connection::Postgres(conn)), + stream: None + } + } +} + +#[cfg(feature = "sqlx-sqlite")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream { + stmt: Arc::new(stmt), + conn: Arc::new(Connection::Sqlite(conn)), + stream: None + } + } +} + +impl std::fmt::Debug for QueryStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "QueryStream") + } +} + +impl QueryStream { + fn get_conn(&mut self) -> Option<&'static mut Connection> { + // this is safe since the connection is owned and the stream, that references the connection, is owned too, so they die tougheter + unsafe { std::mem::transmute(Arc::get_mut(&mut self.conn)) } + } + fn get_stmt(&self) -> &'static Statement { + // this is safe since the statement is owned and the stream, that references the statement, is owned too, so they die tougheter + unsafe { std::mem::transmute(self.stmt.as_ref()) } + } + fn init(&mut self) { + match self.get_conn() { + #[cfg(feature = "sqlx-mysql")] + Some(Connection::MySql(c)) => { + let query = crate::driver::sqlx_mysql::sqlx_query(self.get_stmt()); + self.stream = Some(Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err) + )); + }, + #[cfg(feature = "sqlx-postgres")] + Some(Connection::Postgres(c)) => { + let query = crate::driver::sqlx_postgres::sqlx_query(self.get_stmt()); + self.stream = Some(Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err) + )); + }, + #[cfg(feature = "sqlx-sqlite")] + Some(Connection::Sqlite(c)) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(self.get_stmt()); + self.stream = Some(Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err) + )); + }, + _ => unreachable!(), + } + } +} + +impl Stream for QueryStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + if this.stream.is_none() { + this.init(); + } + if let Some(stream) = this.stream.as_mut() { + stream.as_mut().poll_next(cx) + } + else { + unreachable!(); + } + } +} diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 2235b98b3..5ad7c8428 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -5,7 +5,7 @@ use sqlx::{Connection, MySql, MySqlPool, mysql::{MySqlArguments, MySqlQueryResul sea_query::sea_query_driver_mysql!(); use sea_query_driver_mysql::bind_query; -use crate::{DatabaseConnection, DatabaseTransaction, Statement, TransactionError, debug_print, error::*, executor::*}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -91,6 +91,18 @@ impl SqlxMySqlPoolConnection { } } + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + pub async fn transaction(&self, callback: F) -> Result> where F: for<'b> FnOnce(&'b DatabaseTransaction<'_>) -> Pin> + Send + 'b>> + Send + Sync, diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index 72cb871a6..d737cdd54 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -5,7 +5,7 @@ use sqlx::{Connection, PgPool, Postgres, postgres::{PgArguments, PgQueryResult, sea_query::sea_query_driver_postgres!(); use sea_query_driver_postgres::bind_query; -use crate::{DatabaseConnection, DatabaseTransaction, Statement, TransactionError, debug_print, error::*, executor::*}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -91,6 +91,18 @@ impl SqlxPostgresPoolConnection { } } + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + pub async fn transaction(&self, callback: F) -> Result> where F: for<'b> FnOnce(&'b DatabaseTransaction<'_>) -> Pin> + Send + 'b>> + Send + Sync, diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index 66cf2d5df..93643475e 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -5,7 +5,7 @@ use sqlx::{Connection, Sqlite, SqlitePool, sqlite::{SqliteArguments, SqliteQuery sea_query::sea_query_driver_sqlite!(); use sea_query_driver_sqlite::bind_query; -use crate::{DatabaseConnection, DatabaseTransaction, Statement, TransactionError, debug_print, error::*, executor::*}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -91,6 +91,18 @@ impl SqlxSqlitePoolConnection { } } + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + pub async fn transaction(&self, callback: F) -> Result> where F: for<'b> FnOnce(&'b DatabaseTransaction<'_>) -> Pin> + Send + 'b>> + Send + Sync, diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 079560e60..8b4e2864d 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,4 +1,4 @@ -use crate::{ActiveModelTrait, DbBackend, ConnectionTrait, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*}; +use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*}; use sea_query::InsertStatement; use std::{future::Future, marker::PhantomData}; @@ -36,7 +36,7 @@ where // so that self is dropped before entering await let mut query = self.query; #[cfg(feature = "sqlx-postgres")] - if db.get_database_backend() == DbBackend::Postgres && !db.is_mock_connection() { + if db.get_database_backend() == crate::DbBackend::Postgres && !db.is_mock_connection() { use crate::{sea_query::Query, Iterable}; if ::PrimaryKey::iter().count() > 0 { query.returning( @@ -88,7 +88,7 @@ where type ValueTypeOf = as PrimaryKeyTrait>::ValueType; let last_insert_id = match db.get_database_backend() { #[cfg(feature = "sqlx-postgres")] - DbBackend::Postgres if !db.is_mock_connection() => { + crate::DbBackend::Postgres if !db.is_mock_connection() => { use crate::{sea_query::Iden, Iterable}; let cols = PrimaryKey::::iter() .map(|col| col.to_string()) diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 6b5943f45..3dea86f82 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -1,4 +1,4 @@ -use crate::{DbBackend, ConnectionTrait, SelectorTrait, error::*}; +use crate::{ConnectionTrait, SelectorTrait, error::*}; use async_stream::stream; use futures::Stream; use sea_query::{Alias, Expr, SelectStatement}; @@ -67,7 +67,7 @@ where }; let num_items = match builder { #[cfg(feature = "sqlx-postgres")] - DbBackend::Postgres if !self.db.is_mock_connection() => result.try_get::("", "num_items")? as usize, + crate::DbBackend::Postgres if !self.db.is_mock_connection() => result.try_get::("", "num_items")? as usize, _ => result.try_get::("", "num_items")? as usize, }; Ok(num_items) diff --git a/src/executor/select.rs b/src/executor/select.rs index 31521b5d9..44d703af3 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -1,4 +1,8 @@ +#[cfg(feature = "sqlx-dep")] +use std::pin::Pin; use crate::{ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, error::*}; +#[cfg(feature = "sqlx-dep")] +use futures::{Stream, TryStreamExt}; use sea_query::SelectStatement; use std::marker::PhantomData; @@ -109,6 +113,14 @@ where self.into_model().all(db).await } + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b, DbErr> + where + C: ConnectionTrait, + { + self.into_model().stream(db).await + } + pub fn paginate( self, db: &C, @@ -164,6 +176,14 @@ where self.into_model().all(db).await } + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> + where + C: ConnectionTrait, + { + self.into_model().stream(db).await + } + pub fn paginate( self, db: &C, @@ -211,6 +231,14 @@ where self.into_model().one(db).await } + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> + where + C: ConnectionTrait, + { + self.into_model().stream(db).await + } + pub async fn all( self, db: &C, @@ -256,6 +284,19 @@ where Ok(models) } + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b>>, DbErr> + where + C: ConnectionTrait, + S: 'b, + { + let builder = db.get_database_backend(); + let stream = db.stream(builder.build(&self.query)).await?; + Ok(Box::pin(stream.and_then(|row| { + futures::future::ready(S::from_raw_query_result(row)) + }))) + } + pub fn paginate(self, db: &C, page_size: usize) -> Paginator<'_, C, S> where C: ConnectionTrait { Paginator { diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs new file mode 100644 index 000000000..16313f114 --- /dev/null +++ b/tests/stream_tests.rs @@ -0,0 +1,39 @@ +pub mod common; + +pub use common::{bakery_chain::*, setup::*, TestContext}; +pub use sea_orm::entity::*; +pub use sea_orm::{QueryFilter, ConnectionTrait}; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn stream() { + use futures::StreamExt; + + let ctx = TestContext::new("stream").await; + + let bakery = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(&ctx.db) + .await + .expect("could not insert bakery"); + + let result = Bakery::find_by_id(bakery.id.clone().unwrap()) + .stream(&ctx.db) + .await + .unwrap() + .next() + .await + .unwrap() + .unwrap(); + + assert_eq!(result.id, bakery.id.unwrap()); + + ctx.delete().await; +} From 605058f2d855af9a0aa34ac15e9268567dea3333 Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Tue, 21 Sep 2021 14:58:16 +0200 Subject: [PATCH 02/12] use ouroboros to cover self-references --- Cargo.toml | 1 + src/database/stream.rs | 112 +++++++++++++++++------------------------ 2 files changed, 48 insertions(+), 65 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3e5a9682d..abe43062a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } +ouroboros = "0.10" [dev-dependencies] smol = { version = "^1.2" } diff --git a/src/database/stream.rs b/src/database/stream.rs index a9112c1cd..8977b86e5 100644 --- a/src/database/stream.rs +++ b/src/database/stream.rs @@ -1,4 +1,4 @@ -use std::{pin::Pin, sync::Arc, task::Poll}; +use std::{pin::Pin, task::Poll}; use futures::{Stream, TryStreamExt}; @@ -15,42 +15,33 @@ enum Connection { Sqlite(PoolConnection), } +#[ouroboros::self_referencing] pub struct QueryStream { - stmt: Arc, - conn: Arc, - stream: Option>>>>, + stmt: Statement, + conn: Connection, + #[borrows(mut conn, stmt)] + #[not_covariant] + stream: Pin> + 'this>>, } #[cfg(feature = "sqlx-mysql")] impl From<(PoolConnection, Statement)> for QueryStream { fn from((conn, stmt): (PoolConnection, Statement)) -> Self { - QueryStream { - stmt: Arc::new(stmt), - conn: Arc::new(Connection::MySql(conn)), - stream: None - } + QueryStream::build(stmt, Connection::MySql(conn)) } } #[cfg(feature = "sqlx-postgres")] impl From<(PoolConnection, Statement)> for QueryStream { fn from((conn, stmt): (PoolConnection, Statement)) -> Self { - QueryStream { - stmt: Arc::new(stmt), - conn: Arc::new(Connection::Postgres(conn)), - stream: None - } + QueryStream::build(stmt, Connection::Postgres(conn)) } } #[cfg(feature = "sqlx-sqlite")] impl From<(PoolConnection, Statement)> for QueryStream { fn from((conn, stmt): (PoolConnection, Statement)) -> Self { - QueryStream { - stmt: Arc::new(stmt), - conn: Arc::new(Connection::Sqlite(conn)), - stream: None - } + QueryStream::build(stmt, Connection::Sqlite(conn)) } } @@ -61,45 +52,42 @@ impl std::fmt::Debug for QueryStream { } impl QueryStream { - fn get_conn(&mut self) -> Option<&'static mut Connection> { - // this is safe since the connection is owned and the stream, that references the connection, is owned too, so they die tougheter - unsafe { std::mem::transmute(Arc::get_mut(&mut self.conn)) } - } - fn get_stmt(&self) -> &'static Statement { - // this is safe since the statement is owned and the stream, that references the statement, is owned too, so they die tougheter - unsafe { std::mem::transmute(self.stmt.as_ref()) } - } - fn init(&mut self) { - match self.get_conn() { - #[cfg(feature = "sqlx-mysql")] - Some(Connection::MySql(c)) => { - let query = crate::driver::sqlx_mysql::sqlx_query(self.get_stmt()); - self.stream = Some(Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(sqlx_error_to_query_err) - )); - }, - #[cfg(feature = "sqlx-postgres")] - Some(Connection::Postgres(c)) => { - let query = crate::driver::sqlx_postgres::sqlx_query(self.get_stmt()); - self.stream = Some(Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(sqlx_error_to_query_err) - )); - }, - #[cfg(feature = "sqlx-sqlite")] - Some(Connection::Sqlite(c)) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(self.get_stmt()); - self.stream = Some(Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(sqlx_error_to_query_err) - )); + fn build(stmt: Statement, conn: Connection) -> Self { + QueryStreamBuilder { + stmt, + conn, + stream_builder: |conn, stmt| { + match conn { + #[cfg(feature = "sqlx-mysql")] + Connection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "sqlx-postgres")] + Connection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "sqlx-sqlite")] + Connection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err) + ) + }, + } }, - _ => unreachable!(), - } + }.build() } } @@ -108,14 +96,8 @@ impl Stream for QueryStream { fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { let this = self.get_mut(); - if this.stream.is_none() { - this.init(); - } - if let Some(stream) = this.stream.as_mut() { + this.with_stream_mut(|stream| { stream.as_mut().poll_next(cx) - } - else { - unreachable!(); - } + }) } } From 8a14f0659f20b3803e82446476a52e43da7250fb Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Tue, 21 Sep 2021 15:49:43 +0200 Subject: [PATCH 03/12] Reduce test size --- tests/stream_tests.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs index 16313f114..9a75fafe4 100644 --- a/tests/stream_tests.rs +++ b/tests/stream_tests.rs @@ -10,7 +10,7 @@ pub use sea_orm::{QueryFilter, ConnectionTrait}; feature = "sqlx-sqlite", feature = "sqlx-postgres" ))] -pub async fn stream() { +pub async fn stream() -> Result<(), DbErr> { use futures::StreamExt; let ctx = TestContext::new("stream").await; @@ -21,19 +21,18 @@ pub async fn stream() { ..Default::default() } .save(&ctx.db) - .await - .expect("could not insert bakery"); + .await?; let result = Bakery::find_by_id(bakery.id.clone().unwrap()) .stream(&ctx.db) - .await - .unwrap() + .await? .next() .await - .unwrap() - .unwrap(); + .unwrap()?; assert_eq!(result.id, bakery.id.unwrap()); ctx.delete().await; + + Ok(()) } From 3e9f2205930196fdaaa4040d551b504ff505b9fa Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Tue, 21 Sep 2021 16:01:12 +0200 Subject: [PATCH 04/12] Reduce test size --- tests/stream_tests.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs index 9a75fafe4..a6c22dc89 100644 --- a/tests/stream_tests.rs +++ b/tests/stream_tests.rs @@ -2,7 +2,8 @@ pub mod common; pub use common::{bakery_chain::*, setup::*, TestContext}; pub use sea_orm::entity::*; -pub use sea_orm::{QueryFilter, ConnectionTrait}; +pub use sea_orm::{QueryFilter, ConnectionTrait, DbErr}; +use futures::StreamExt; #[sea_orm_macros::test] #[cfg(any( @@ -11,8 +12,6 @@ pub use sea_orm::{QueryFilter, ConnectionTrait}; feature = "sqlx-postgres" ))] pub async fn stream() -> Result<(), DbErr> { - use futures::StreamExt; - let ctx = TestContext::new("stream").await; let bakery = bakery::ActiveModel { From 09e6e1ee5ed7cd24f76345f8368ae96294ef86d7 Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Sun, 26 Sep 2021 22:26:12 +0200 Subject: [PATCH 05/12] Complete transaction rewrite + streams --- src/database/connection.rs | 42 +++- src/database/db_connection.rs | 31 ++- src/database/db_transaction.rs | 396 ++++++++++++++++++++++----------- src/database/mod.rs | 2 - src/database/stream.rs | 61 ++--- src/driver/mock.rs | 19 +- src/driver/sqlx_mysql.rs | 30 ++- src/driver/sqlx_postgres.rs | 30 ++- src/driver/sqlx_sqlite.rs | 30 ++- src/entity/active_model.rs | 24 +- src/executor/delete.rs | 20 +- src/executor/insert.rs | 22 +- src/executor/paginator.rs | 4 +- src/executor/select.rs | 76 +++---- src/executor/update.rs | 34 +-- tests/transaction_tests.rs | 4 +- 16 files changed, 518 insertions(+), 307 deletions(-) diff --git a/src/database/connection.rs b/src/database/connection.rs index df59d2951..349694486 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,5 +1,5 @@ -use std::{pin::Pin, future::Future}; -use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, error::*}; +use std::{future::Future, pin::Pin}; +use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, QueryStream, Statement, StatementBuilder, TransactionError, error::*}; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; #[cfg_attr(not(feature = "mock"), derive(Clone))] @@ -53,7 +53,7 @@ impl std::fmt::Debug for DatabaseConnection { } #[async_trait::async_trait] -impl ConnectionTrait for DatabaseConnection { +impl<'a> ConnectionTrait<'a> for DatabaseConnection { fn get_database_backend(&self) -> DbBackend { match self { #[cfg(feature = "sqlx-mysql")] @@ -77,7 +77,7 @@ impl ConnectionTrait for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } @@ -91,7 +91,7 @@ impl ConnectionTrait for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } @@ -105,13 +105,12 @@ impl ConnectionTrait for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } - #[cfg(feature = "sqlx-dep")] - async fn stream(&self, stmt: Statement) -> Result { + async fn stream(&'a self, stmt: Statement) -> Result, DbErr> { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await, @@ -120,16 +119,32 @@ impl ConnectionTrait for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(_) => panic!("Mock"),//TODO: can it be permitted? How? + DatabaseConnection::MockDatabaseConnection(conn) => Ok(QueryStream::from((conn, stmt))), + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + + async fn begin(&'a self) -> Result, DbErr> { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => DatabaseTransaction::new_mock(conn).await, DatabaseConnection::Disconnected => panic!("Disconnected"), } } /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&self, _callback: F) -> Result> + async fn transaction(&'a self, _callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'_>) -> Pin> + Send + 'c>> + Send + Sync, + F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + Send + 'c>> + Send + Sync, + // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, + // Fut: Future> + Send, T: Send, E: std::error::Error + Send, { @@ -141,7 +156,10 @@ impl ConnectionTrait for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(_) => unimplemented!(), //TODO: support transaction in mock connection + DatabaseConnection::MockDatabaseConnection(conn) => { + let transaction = DatabaseTransaction::new_mock(conn).await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await + }, DatabaseConnection::Disconnected => panic!("Disconnected"), } } diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 02476e439..0a531f174 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -1,8 +1,22 @@ -use std::{pin::Pin, future::Future}; -use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError}; +use std::{future::Future, pin::Pin}; +use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, MockDatabaseConnection, QueryResult, QueryStream, Statement, TransactionError}; +#[cfg(feature = "sqlx-dep")] +use sqlx::pool::PoolConnection; + +pub(crate) enum InnerConnection<'a> { + #[cfg(feature = "sqlx-mysql")] + MySql(PoolConnection), + #[cfg(feature = "sqlx-postgres")] + Postgres(PoolConnection), + #[cfg(feature = "sqlx-sqlite")] + Sqlite(PoolConnection), + #[cfg(feature = "mock")] + Mock(&'a MockDatabaseConnection), + Transaction(Box<&'a DatabaseTransaction<'a>>), +} #[async_trait::async_trait] -pub trait ConnectionTrait: Sync { +pub trait ConnectionTrait<'a>: Sync { fn get_database_backend(&self) -> DbBackend; async fn execute(&self, stmt: Statement) -> Result; @@ -11,14 +25,17 @@ pub trait ConnectionTrait: Sync { async fn query_all(&self, stmt: Statement) -> Result, DbErr>; - #[cfg(feature = "sqlx-dep")] - async fn stream(&self, stmt: Statement) -> Result; + async fn stream(&'a self, stmt: Statement) -> Result, DbErr>; + + async fn begin(&'a self) -> Result, DbErr>; /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&self, callback: F) -> Result> + async fn transaction(&'a self, callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'_>) -> Pin> + Send + 'c>> + Send + Sync, + F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + Send + 'c>> + Send + Sync, + // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, + // Fut: Future> + Send, T: Send, E: std::error::Error + Send; diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs index 69f2b68b2..419bb6a9e 100644 --- a/src/database/db_transaction.rs +++ b/src/database/db_transaction.rs @@ -1,153 +1,304 @@ -use std::{pin::Pin, future::Future}; -use crate::{DbBackend, ConnectionTrait, DbErr, ExecResult, QueryResult, Statement, debug_print}; +use std::{cell::UnsafeCell, future::Future, pin::Pin}; +use crate::{ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, QueryStream, Statement, debug_print}; +use futures::{Stream, TryStreamExt}; #[cfg(feature = "sqlx-dep")] use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; #[cfg(feature = "sqlx-dep")] -use sqlx::Connection; +use sqlx::{pool::PoolConnection, Executor, TransactionManager}; -#[cfg(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite"))] -use futures::lock::Mutex; +// a Transaction is just a sugar for a connection where START TRANSACTION has been executed +pub struct DatabaseTransaction<'a> { + // using Option we don't even need an "open" flag + conn: Option>>, +} -#[derive(Debug)] -pub enum DatabaseTransaction<'a> { +impl<'a> std::fmt::Debug for DatabaseTransaction<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DatabaseTransaction") + } +} + +impl<'a> DatabaseTransaction<'a> { #[cfg(feature = "sqlx-mysql")] - SqlxMySqlTransaction(Mutex>), + pub(crate) async fn new_mysql(inner: PoolConnection) -> Result, DbErr> { + Self::build(InnerConnection::MySql(inner)).await + } + #[cfg(feature = "sqlx-postgres")] - SqlxPostgresTransaction(Mutex>), - #[cfg(feature = "sqlx-sqlite")] - SqlxSqliteTransaction(Mutex>), - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - None(&'a ()), -} + pub(crate) async fn new_postgres(inner: PoolConnection) -> Result, DbErr> { + Self::build(InnerConnection::Postgres(inner)).await + } -#[cfg(feature = "sqlx-mysql")] -impl<'a> From> for DatabaseTransaction<'a> { - fn from(inner: sqlx::Transaction<'a, sqlx::MySql>) -> Self { - DatabaseTransaction::SqlxMySqlTransaction(Mutex::new(inner)) + #[cfg(feature = "sqlx-sqlite")] + pub(crate) async fn new_sqlite(inner: PoolConnection) -> Result, DbErr> { + Self::build(InnerConnection::Sqlite(inner)).await } -} -#[cfg(feature = "sqlx-postgres")] -impl<'a> From> for DatabaseTransaction<'a> { - fn from(inner: sqlx::Transaction<'a, sqlx::Postgres>) -> Self { - DatabaseTransaction::SqlxPostgresTransaction(Mutex::new(inner)) + #[cfg(feature = "mock")] + pub(crate) async fn new_mock(inner: &'a crate::MockDatabaseConnection) -> Result, DbErr> { + Self::build(InnerConnection::Mock(inner)).await } -} -#[cfg(feature = "sqlx-sqlite")] -impl<'a> From> for DatabaseTransaction<'a> { - fn from(inner: sqlx::Transaction<'a, sqlx::Sqlite>) -> Self { - DatabaseTransaction::SqlxSqliteTransaction(Mutex::new(inner)) + async fn build(conn: InnerConnection<'a>) -> Result, DbErr> { + let mut res = DatabaseTransaction { + conn: Some(UnsafeCell::new(conn)), + }; + match res.conn.as_mut().map(|c| c.get_mut()) { + #[cfg(feature = "sqlx-mysql")] + Some(InnerConnection::MySql(c)) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + Some(InnerConnection::Postgres(c)) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + Some(InnerConnection::Sqlite(c)) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + // should we do something for mocked connections? + #[cfg(feature = "mock")] + Some(InnerConnection::Mock(_)) => {}, + // nested transactions should already have been started + Some(InnerConnection::Transaction(_)) => {}, + _ => unreachable!(), + } + Ok(res) } -} -#[allow(dead_code)] -impl<'a> DatabaseTransaction<'a> { - pub(crate) async fn run(self, callback: F) -> Result> + pub(crate) async fn run(self, callback: F) -> Result> where F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + Send + 'b>> + Send + Sync, + // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, + // Fut: Future> + Send, T: Send, E: std::error::Error + Send, { let res = callback(&self).await.map_err(|e| TransactionError::Transaction(e)); if res.is_ok() { - self.commit().await?; + self.commit().await.map_err(|e| TransactionError::Connection(e))?; } else { - self.rollback().await?; + self.rollback().await.map_err(|e| TransactionError::Connection(e))?; } res } - async fn commit(self) -> Result<(), TransactionError> - where E: std::error::Error { - match self { + pub async fn commit(mut self) -> Result<(), DbErr> { + match self.conn.take().map(|c| c.into_inner()) { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.commit().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + Some(InnerConnection::MySql(ref mut c)) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.commit().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + Some(InnerConnection::Postgres(ref mut c)) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.commit().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + Some(InnerConnection::Sqlite(ref mut c)) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + Some(InnerConnection::Transaction(c)) => c.ref_commit().await?, + //Should we do something for mocked &connections? + #[cfg(feature = "mock")] + Some(InnerConnection::Mock(_)) => {}, + _ => unreachable!(), } + Ok(()) } - async fn rollback(self) -> Result<(), TransactionError> - where E: std::error::Error { - match self { + // non destructive commit + fn ref_commit(&'a self) -> Pin> + Send + 'a>> { + Box::pin(async move { + if self.conn.is_some() { + match self.get_conn() { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? + }, + InnerConnection::Transaction(c) => c.ref_commit().await?, + //Should we do something for mocked &connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + } + Ok(()) + }) + } + + pub async fn rollback(mut self) -> Result<(), DbErr> { + match self.conn.take().map(|c| c.into_inner()) { + #[cfg(feature = "sqlx-mysql")] + Some(InnerConnection::MySql(ref mut c)) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + Some(InnerConnection::Postgres(ref mut c)) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + Some(InnerConnection::Sqlite(ref mut c)) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + Some(InnerConnection::Transaction(c)) => c.ref_rollback().await?, + //Should we do something for mocked &connections? + #[cfg(feature = "mock")] + Some(InnerConnection::Mock(_)) => {}, + _ => unreachable!(), + } + Ok(()) + } + + // non destructive rollback + fn ref_rollback(&'a self) -> Pin> + Send + 'a>> { + Box::pin(async move { + if self.conn.is_some() { + match self.get_conn() { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? + }, + InnerConnection::Transaction(c) => c.ref_rollback().await?, + //Should we do something for mocked &connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + } + Ok(()) + }) + } + + pub(crate) fn fetch<'b>(&'b self, stmt: &'b Statement) -> Pin> + 'b>> { + match self.get_conn() { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.rollback().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + InnerConnection::MySql(inner) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin(inner.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err)) }, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.rollback().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + InnerConnection::Postgres(inner) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin(inner.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err)) }, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.rollback().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + InnerConnection::Sqlite(inner) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin(inner.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err)) + }, + InnerConnection::Transaction(inner) => { + inner.fetch(stmt) }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + #[cfg(feature = "mock")] + InnerConnection::Mock(inner) => { + inner.fetch(stmt) + }, + } + } + + // the rollback is queued and will be performed on next async operation, like returning the connection to the pool + fn start_rollback(&self) { + if let Some(conn) = self.conn.as_ref() { + match unsafe { &mut *conn.get() } { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + ::TransactionManager::start_rollback(c); + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + ::TransactionManager::start_rollback(c); + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + ::TransactionManager::start_rollback(c); + }, + InnerConnection::Transaction(c) => { + c.start_rollback(); + } + //Should we do something for mocked &connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } } } + + fn get_conn(&self) -> &mut InnerConnection<'a> { + unsafe { &mut *self.conn.as_ref().map(|c| c.get()).unwrap() } + } } +impl<'a> Drop for DatabaseTransaction<'a> { + fn drop(&mut self) { + self.start_rollback(); + } +} + +// this is needed since sqlite connections aren't sync +unsafe impl<'a> Sync for DatabaseTransaction<'a> {} + #[async_trait::async_trait] -impl<'a> ConnectionTrait for DatabaseTransaction<'a> { +impl<'a> ConnectionTrait<'a> for DatabaseTransaction<'a> { fn get_database_backend(&self) -> DbBackend { - match self { + match self.conn.as_ref().map(|c| unsafe { &*c.get() }) { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(_) => DbBackend::MySql, + Some(InnerConnection::MySql(_)) => DbBackend::MySql, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(_) => DbBackend::Postgres, + Some(InnerConnection::Postgres(_)) => DbBackend::Postgres, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(_) => DbBackend::Sqlite, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + Some(InnerConnection::Sqlite(_)) => DbBackend::Sqlite, + #[cfg(feature = "mock")] + Some(InnerConnection::Mock(c)) => c.get_database_backend(), + Some(InnerConnection::Transaction(c)) => c.get_database_backend(), + _ => unreachable!(), } } async fn execute(&self, stmt: Statement) -> Result { debug_print!("{}", stmt); - let _res = match self { + let _res = match self.get_conn() { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(conn) => { + InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.execute(&mut *conn).await + query.execute(conn).await .map(Into::into) }, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(conn) => { + InnerConnection::Postgres(conn) => { let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.execute(&mut *conn).await + query.execute(conn).await .map(Into::into) }, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(conn) => { + InnerConnection::Sqlite(conn) => { let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.execute(&mut *conn).await + query.execute(conn).await .map(Into::into) }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.execute(stmt), + InnerConnection::Transaction(conn) => return conn.execute(stmt).await, }; #[cfg(feature = "sqlx-dep")] _res.map_err(sqlx_error_to_exec_err) @@ -156,30 +307,28 @@ impl<'a> ConnectionTrait for DatabaseTransaction<'a> { async fn query_one(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); - let _res = match self { + let _res = match self.get_conn() { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(conn) => { + InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_one(&mut *conn).await + query.fetch_one(conn).await .map(|row| Some(row.into())) }, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(conn) => { + InnerConnection::Postgres(conn) => { let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_one(&mut *conn).await + query.fetch_one(conn).await .map(|row| Some(row.into())) }, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(conn) => { + InnerConnection::Sqlite(conn) => { let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_one(&mut *conn).await + query.fetch_one(conn).await .map(|row| Some(row.into())) }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_one(stmt), + InnerConnection::Transaction(conn) => return conn.query_one(stmt).await, }; #[cfg(feature = "sqlx-dep")] if let Err(sqlx::Error::RowNotFound) = _res { @@ -193,70 +342,53 @@ impl<'a> ConnectionTrait for DatabaseTransaction<'a> { async fn query_all(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); - let _res = match self { + let _res = match self.get_conn() { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(conn) => { + InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_all(&mut *conn).await + query.fetch_all(conn).await .map(|rows| rows.into_iter().map(|r| r.into()).collect()) }, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(conn) => { + InnerConnection::Postgres(conn) => { let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_all(&mut *conn).await + query.fetch_all(conn).await .map(|rows| rows.into_iter().map(|r| r.into()).collect()) }, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(conn) => { + InnerConnection::Sqlite(conn) => { let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_all(&mut *conn).await + query.fetch_all(conn).await .map(|rows| rows.into_iter().map(|r| r.into()).collect()) }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_all(stmt), + InnerConnection::Transaction(conn) => return conn.query_all(stmt).await, }; #[cfg(feature = "sqlx-dep")] _res.map_err(sqlx_error_to_query_err) } - #[cfg(feature = "sqlx-dep")] - async fn stream(&self, _stmt: Statement) -> Result { - todo!(); + async fn stream(&'a self, stmt: Statement) -> Result, DbErr> { + Ok(QueryStream::from((self, stmt))) + } + + async fn begin(&'a self) -> Result, DbErr> { + DatabaseTransaction::build(InnerConnection::Transaction(Box::new(self))).await } /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&self, _callback: F) -> Result> + async fn transaction(&'a self, _callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'_>) -> Pin> + Send + 'c>> + Send + Sync, + F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + Send + 'c>> + Send + Sync, + // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, + // Fut: Future> + Send, T: Send, E: std::error::Error + Send, { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(conn) => { - let mut conn = conn.lock().await; - let transaction = DatabaseTransaction::from(conn.begin().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string())))?); - transaction.run(_callback).await - }, - #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(conn) => { - let mut conn = conn.lock().await; - let transaction = DatabaseTransaction::from(conn.begin().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string())))?); - transaction.run(_callback).await - }, - #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(conn) => { - let mut conn = conn.lock().await; - let transaction = DatabaseTransaction::from(conn.begin().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string())))?); - transaction.run(_callback).await - }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), - } + let transaction = self.begin().await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await } } diff --git a/src/database/mod.rs b/src/database/mod.rs index 464b76563..c8db4b99c 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -5,7 +5,6 @@ mod statement; mod transaction; mod db_connection; mod db_transaction; -#[cfg(feature = "sqlx-dep")] mod stream; pub use connection::*; @@ -15,7 +14,6 @@ pub use statement::*; pub use transaction::*; pub use db_connection::*; pub use db_transaction::*; -#[cfg(feature = "sqlx-dep")] pub use stream::*; use crate::DbErr; diff --git a/src/database/stream.rs b/src/database/stream.rs index 8977b86e5..fdc83ca17 100644 --- a/src/database/stream.rs +++ b/src/database/stream.rs @@ -4,62 +4,66 @@ use futures::{Stream, TryStreamExt}; use sqlx::{pool::PoolConnection, Executor}; -use crate::{sqlx_error_to_query_err, DbErr, QueryResult, Statement}; - -enum Connection { - #[cfg(feature = "sqlx-mysql")] - MySql(PoolConnection), - #[cfg(feature = "sqlx-postgres")] - Postgres(PoolConnection), - #[cfg(feature = "sqlx-sqlite")] - Sqlite(PoolConnection), -} +use crate::{DatabaseTransaction, DbErr, InnerConnection, QueryResult, Statement, sqlx_error_to_query_err}; #[ouroboros::self_referencing] -pub struct QueryStream { +pub struct QueryStream<'a> { stmt: Statement, - conn: Connection, + conn: InnerConnection<'a>, #[borrows(mut conn, stmt)] #[not_covariant] stream: Pin> + 'this>>, } #[cfg(feature = "sqlx-mysql")] -impl From<(PoolConnection, Statement)> for QueryStream { +impl<'a> From<(PoolConnection, Statement)> for QueryStream<'a> { fn from((conn, stmt): (PoolConnection, Statement)) -> Self { - QueryStream::build(stmt, Connection::MySql(conn)) + QueryStream::build(stmt, InnerConnection::MySql(conn)) } } #[cfg(feature = "sqlx-postgres")] -impl From<(PoolConnection, Statement)> for QueryStream { +impl<'a> From<(PoolConnection, Statement)> for QueryStream<'a> { fn from((conn, stmt): (PoolConnection, Statement)) -> Self { - QueryStream::build(stmt, Connection::Postgres(conn)) + QueryStream::build(stmt, InnerConnection::Postgres(conn)) } } #[cfg(feature = "sqlx-sqlite")] -impl From<(PoolConnection, Statement)> for QueryStream { +impl<'a> From<(PoolConnection, Statement)> for QueryStream<'a> { fn from((conn, stmt): (PoolConnection, Statement)) -> Self { - QueryStream::build(stmt, Connection::Sqlite(conn)) + QueryStream::build(stmt, InnerConnection::Sqlite(conn)) + } +} + +#[cfg(feature = "mock")] +impl<'a> From<(&'a crate::MockDatabaseConnection, Statement)> for QueryStream<'a> { + fn from((conn, stmt): (&'a crate::MockDatabaseConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Mock(conn)) } } -impl std::fmt::Debug for QueryStream { +impl<'a> From<(&'a DatabaseTransaction<'a>, Statement)> for QueryStream<'a> { + fn from((conn, stmt): (&'a DatabaseTransaction<'a>, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Transaction(Box::new(conn))) + } +} + +impl<'a> std::fmt::Debug for QueryStream<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "QueryStream") } } -impl QueryStream { - fn build(stmt: Statement, conn: Connection) -> Self { +impl<'a> QueryStream<'a> { + fn build(stmt: Statement, conn: InnerConnection<'a>) -> QueryStream<'a> { QueryStreamBuilder { stmt, conn, stream_builder: |conn, stmt| { match conn { #[cfg(feature = "sqlx-mysql")] - Connection::MySql(c) => { + InnerConnection::MySql(c) => { let query = crate::driver::sqlx_mysql::sqlx_query(stmt); Box::pin( c.fetch(query) @@ -68,7 +72,7 @@ impl QueryStream { ) }, #[cfg(feature = "sqlx-postgres")] - Connection::Postgres(c) => { + InnerConnection::Postgres(c) => { let query = crate::driver::sqlx_postgres::sqlx_query(stmt); Box::pin( c.fetch(query) @@ -77,7 +81,7 @@ impl QueryStream { ) }, #[cfg(feature = "sqlx-sqlite")] - Connection::Sqlite(c) => { + InnerConnection::Sqlite(c) => { let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); Box::pin( c.fetch(query) @@ -85,13 +89,20 @@ impl QueryStream { .map_err(sqlx_error_to_query_err) ) }, + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => { + c.fetch(stmt) + }, + InnerConnection::Transaction(c) => { + c.fetch(stmt) + }, } }, }.build() } } -impl Stream for QueryStream { +impl<'a> Stream for QueryStream<'a> { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 823ddb32d..4aeb024e6 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -2,11 +2,11 @@ use crate::{ debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, Statement, Transaction, }; -use std::fmt::Debug; -use std::sync::{ +use std::{fmt::Debug, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, Mutex, -}; +}}; +use futures::Stream; #[derive(Debug)] pub struct MockDatabaseConnector; @@ -86,25 +86,32 @@ impl MockDatabaseConnection { &self.mocker } - pub async fn execute(&self, statement: Statement) -> Result { + pub fn execute(&self, statement: Statement) -> Result { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); self.mocker.lock().unwrap().execute(counter, statement) } - pub async fn query_one(&self, statement: Statement) -> Result, DbErr> { + pub fn query_one(&self, statement: Statement) -> Result, DbErr> { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); let result = self.mocker.lock().unwrap().query(counter, statement)?; Ok(result.into_iter().next()) } - pub async fn query_all(&self, statement: Statement) -> Result, DbErr> { + pub fn query_all(&self, statement: Statement) -> Result, DbErr> { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); self.mocker.lock().unwrap().query(counter, statement) } + pub fn fetch(&self, statement: &Statement) -> Pin>>> { + match self.query_all(statement.clone()) { + Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(|r| Ok(r)))), + Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())), + } + } + pub fn get_database_backend(&self) -> DbBackend { self.mocker.lock().unwrap().get_database_backend() } diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 5ad7c8428..e6189392a 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -1,6 +1,6 @@ -use std::{pin::Pin, future::Future}; +use std::{future::Future, pin::Pin}; -use sqlx::{Connection, MySql, MySqlPool, mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}}; +use sqlx::{MySql, MySqlPool, mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}}; sea_query::sea_query_driver_mysql!(); use sea_query_driver_mysql::bind_query; @@ -91,7 +91,7 @@ impl SqlxMySqlPoolConnection { } } - pub async fn stream(&self, stmt: Statement) -> Result { + pub async fn stream(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); if let Ok(conn) = self.pool.acquire().await { @@ -103,18 +103,26 @@ impl SqlxMySqlPoolConnection { } } - pub async fn transaction(&self, callback: F) -> Result> + pub async fn begin(&self) -> Result, DbErr> { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_mysql(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction<'a, F, T, E/*, Fut*/>(&'a self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'_>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + Send + 'b>> + Send + Sync, + // F: FnOnce(&DatabaseTransaction<'_>) -> Fut + Send, + // Fut: Future> + Send, T: Send, E: std::error::Error + Send, { - if let Ok(conn) = &mut self.pool.acquire().await { - let transaction = DatabaseTransaction::from( - conn.begin().await.map_err(|e| { - TransactionError::Connection(DbErr::Query(e.to_string())) - })? - ); + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_mysql(conn).await.map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } else { Err(TransactionError::Connection(DbErr::Query( diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index d737cdd54..0e4a7dda1 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -1,6 +1,6 @@ -use std::{pin::Pin, future::Future}; +use std::{future::Future, pin::Pin}; -use sqlx::{Connection, PgPool, Postgres, postgres::{PgArguments, PgQueryResult, PgRow}}; +use sqlx::{PgPool, Postgres, postgres::{PgArguments, PgQueryResult, PgRow}}; sea_query::sea_query_driver_postgres!(); use sea_query_driver_postgres::bind_query; @@ -91,7 +91,7 @@ impl SqlxPostgresPoolConnection { } } - pub async fn stream(&self, stmt: Statement) -> Result { + pub async fn stream(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); if let Ok(conn) = self.pool.acquire().await { @@ -103,18 +103,26 @@ impl SqlxPostgresPoolConnection { } } - pub async fn transaction(&self, callback: F) -> Result> + pub async fn begin(&self) -> Result, DbErr> { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_postgres(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction<'a, F, T, E/*, Fut*/>(&'a self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'_>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + Send + 'b>> + Send + Sync, + // F: FnOnce(&DatabaseTransaction<'_>) -> Fut + Send, + // Fut: Future> + Send, T: Send, E: std::error::Error + Send, { - if let Ok(conn) = &mut self.pool.acquire().await { - let transaction = DatabaseTransaction::from( - conn.begin().await.map_err(|e| { - TransactionError::Connection(DbErr::Query(e.to_string())) - })? - ); + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_postgres(conn).await.map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } else { Err(TransactionError::Connection(DbErr::Query( diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index 93643475e..9b168a6d9 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,6 +1,6 @@ -use std::{pin::Pin, future::Future}; +use std::{future::Future, pin::Pin}; -use sqlx::{Connection, Sqlite, SqlitePool, sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}}; +use sqlx::{Sqlite, SqlitePool, sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}}; sea_query::sea_query_driver_sqlite!(); use sea_query_driver_sqlite::bind_query; @@ -91,7 +91,7 @@ impl SqlxSqlitePoolConnection { } } - pub async fn stream(&self, stmt: Statement) -> Result { + pub async fn stream(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); if let Ok(conn) = self.pool.acquire().await { @@ -103,18 +103,26 @@ impl SqlxSqlitePoolConnection { } } - pub async fn transaction(&self, callback: F) -> Result> + pub async fn begin(&self) -> Result, DbErr> { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_sqlite(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn transaction<'a, F, T, E/*, Fut*/>(&'a self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'_>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + Send + 'b>> + Send + Sync, + // F: FnOnce(&DatabaseTransaction<'_>) -> Fut + Send, + // Fut: Future> + Send, T: Send, E: std::error::Error + Send, { - if let Ok(conn) = &mut self.pool.acquire().await { - let transaction = DatabaseTransaction::from( - conn.begin().await.map_err(|e| { - TransactionError::Connection(DbErr::Query(e.to_string())) - })? - ); + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_sqlite(conn).await.map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } else { Err(TransactionError::Connection(DbErr::Query( diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index a7cb00553..752e6f088 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -67,10 +67,11 @@ pub trait ActiveModelTrait: Clone + Debug { fn default() -> Self; - async fn insert(self, db: &C) -> Result + async fn insert<'a, 'b: 'a, C>(self, db: &'a C) -> Result where ::Model: IntoActiveModel, - C: ConnectionTrait, + C: ConnectionTrait<'b>, + Self: 'a, { let am = self; let exec = ::insert(am).exec(db); @@ -91,19 +92,22 @@ pub trait ActiveModelTrait: Clone + Debug { } } - async fn update(self, db: &C) -> Result - where C: ConnectionTrait { + async fn update<'a, 'b: 'a, C>(self, db: &'a C) -> Result + where + C: ConnectionTrait<'b>, + Self: 'a, + { let exec = Self::Entity::update(self).exec(db); exec.await } /// Insert the model if primary key is unset, update otherwise. /// Only works if the entity has auto increment primary key. - async fn save(self, db: &C) -> Result + async fn save<'a, 'b: 'a, C>(self, db: &'a C) -> Result where - Self: ActiveModelBehavior, + Self: ActiveModelBehavior + 'a, ::Model: IntoActiveModel, - C: ConnectionTrait, + C: ConnectionTrait<'b>, { let mut am = self; am = ActiveModelBehavior::before_save(am); @@ -125,10 +129,10 @@ pub trait ActiveModelTrait: Clone + Debug { } /// Delete an active model by its primary key - async fn delete(self, db: &C) -> Result + async fn delete<'a, C>(self, db: &'a C) -> Result where - Self: ActiveModelBehavior, - C: ConnectionTrait, + Self: ActiveModelBehavior + 'a, + C: ConnectionTrait<'a>, { let mut am = self; am = ActiveModelBehavior::before_delete(am); diff --git a/src/executor/delete.rs b/src/executor/delete.rs index 34c848e23..85b37cb09 100644 --- a/src/executor/delete.rs +++ b/src/executor/delete.rs @@ -20,7 +20,7 @@ where self, db: &'a C, ) -> impl Future> + 'a - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -34,7 +34,7 @@ where self, db: &'a C, ) -> impl Future> + 'a - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -45,27 +45,27 @@ impl Deleter { Self { query } } - pub fn exec( + pub fn exec<'a, C>( self, - db: &C, + db: &'a C, ) -> impl Future> + '_ - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); exec_delete(builder.build(&self.query), db) } } -async fn exec_delete_only( +async fn exec_delete_only<'a, C>( query: DeleteStatement, - db: &C, + db: &'a C, ) -> Result -where C: ConnectionTrait { +where C: ConnectionTrait<'a> { Deleter::new(query).exec(db).await } // Only Statement impl Send -async fn exec_delete(statement: Statement, db: &C) -> Result -where C: ConnectionTrait { +async fn exec_delete<'a, C>(statement: Statement, db: &C) -> Result +where C: ConnectionTrait<'a> { let result = db.execute(statement).await?; Ok(DeleteResult { rows_affected: result.rows_affected(), diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 8b4e2864d..4f699ac84 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,6 +1,6 @@ use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*}; use sea_query::InsertStatement; -use std::{future::Future, marker::PhantomData}; +use std::marker::PhantomData; #[derive(Clone, Debug)] pub struct Inserter @@ -24,12 +24,12 @@ where A: ActiveModelTrait, { #[allow(unused_mut)] - pub fn exec<'a, C>( + pub async fn exec<'a, 'b: 'a, C>( self, db: &'a C, - ) -> impl Future, DbErr>> + 'a + ) -> Result, DbErr> where - C: ConnectionTrait, + C: ConnectionTrait<'b>, A: 'a, { // TODO: extract primary key's value from query @@ -46,7 +46,7 @@ where ); } } - Inserter::::new(query).exec(db) + Inserter::::new(query).exec(db).await // TODO: return primary key if extracted before, otherwise use InsertResult } } @@ -62,26 +62,26 @@ where } } - pub fn exec<'a, C>( + pub async fn exec<'a, 'b: 'a, C>( self, db: &'a C, - ) -> impl Future, DbErr>> + 'a + ) -> Result, DbErr> where - C: ConnectionTrait, + C: ConnectionTrait<'b>, A: 'a, { let builder = db.get_database_backend(); - exec_insert(builder.build(&self.query), db) + exec_insert(builder.build(&self.query), db).await } } // Only Statement impl Send -async fn exec_insert( +async fn exec_insert<'a, 'b: 'a, A, C>( statement: Statement, db: &C, ) -> Result, DbErr> where - C: ConnectionTrait, + C: ConnectionTrait<'b>, A: ActiveModelTrait, { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 3dea86f82..b1db9342c 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -9,7 +9,7 @@ pub type PinBoxStream<'db, Item> = Pin + 'db>>; #[derive(Clone, Debug)] pub struct Paginator<'db, C, S> where - C: ConnectionTrait, + C: ConnectionTrait<'db>, S: SelectorTrait + 'db, { pub(crate) query: SelectStatement, @@ -23,7 +23,7 @@ where impl<'db, C, S> Paginator<'db, C, S> where - C: ConnectionTrait, + C: ConnectionTrait<'db>, S: SelectorTrait + 'db, { /// Fetch a specific page diff --git a/src/executor/select.rs b/src/executor/select.rs index 44d703af3..acca03691 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -103,35 +103,35 @@ where } } - pub async fn one(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { self.into_model().all(db).await } #[cfg(feature = "sqlx-dep")] pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b, DbErr> where - C: ConnectionTrait, + C: ConnectionTrait<'a>, { self.into_model().stream(db).await } - pub fn paginate( + pub fn paginate<'a, C>( self, - db: &C, + db: &'a C, page_size: usize, - ) -> Paginator<'_, C, SelectModel> - where C: ConnectionTrait { + ) -> Paginator<'a, C, SelectModel> + where C: ConnectionTrait<'a> { self.into_model().paginate(db, page_size) } - pub async fn count(self, db: &C) -> Result - where C: ConnectionTrait { + pub async fn count<'a, C>(self, db: &'a C) -> Result + where C: ConnectionTrait<'a> { self.paginate(db, 1).num_items().await } } @@ -160,41 +160,41 @@ where } } - pub async fn one( + pub async fn one<'a, C>( self, db: &C, ) -> Result)>, DbErr> - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all( + pub async fn all<'a, C>( self, db: &C, ) -> Result)>, DbErr> - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { self.into_model().all(db).await } #[cfg(feature = "sqlx-dep")] pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> where - C: ConnectionTrait, + C: ConnectionTrait<'a>, { self.into_model().stream(db).await } - pub fn paginate( + pub fn paginate<'a, C>( self, - db: &C, + db: &'a C, page_size: usize, - ) -> Paginator<'_, C, SelectTwoModel> - where C: ConnectionTrait { + ) -> Paginator<'a, C, SelectTwoModel> + where C: ConnectionTrait<'a> { self.into_model().paginate(db, page_size) } - pub async fn count(self, db: &C) -> Result - where C: ConnectionTrait { + pub async fn count<'a, C>(self, db: &'a C) -> Result + where C: ConnectionTrait<'a> { self.paginate(db, 1).num_items().await } } @@ -223,27 +223,27 @@ where } } - pub async fn one( + pub async fn one<'a, C>( self, db: &C, ) -> Result)>, DbErr> - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { self.into_model().one(db).await } #[cfg(feature = "sqlx-dep")] pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> where - C: ConnectionTrait, + C: ConnectionTrait<'a>, { self.into_model().stream(db).await } - pub async fn all( + pub async fn all<'a, C>( self, db: &C, ) -> Result)>, DbErr> - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { let rows = self.into_model().all(db).await?; Ok(consolidate_query_result::(rows)) } @@ -262,8 +262,8 @@ impl Selector where S: SelectorTrait, { - pub async fn one(mut self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn one<'a, C>(mut self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); self.query.limit(1); let row = db.query_one(builder.build(&self.query)).await?; @@ -273,8 +273,8 @@ where } } - pub async fn all(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); let rows = db.query_all(builder.build(&self.query)).await?; let mut models = Vec::new(); @@ -287,7 +287,7 @@ where #[cfg(feature = "sqlx-dep")] pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b>>, DbErr> where - C: ConnectionTrait, + C: ConnectionTrait<'a>, S: 'b, { let builder = db.get_database_backend(); @@ -297,8 +297,8 @@ where }))) } - pub fn paginate(self, db: &C, page_size: usize) -> Paginator<'_, C, S> - where C: ConnectionTrait { + pub fn paginate<'a, C>(self, db: &'a C, page_size: usize) -> Paginator<'a, C, S> + where C: ConnectionTrait<'a> { Paginator { query: self.query, page: 0, @@ -492,8 +492,8 @@ where /// ),] /// ); /// ``` - pub async fn one(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let row = db.query_one(self.stmt).await?; match row { Some(row) => Ok(Some(S::from_raw_query_result(row)?)), @@ -532,8 +532,8 @@ where /// ),] /// ); /// ``` - pub async fn all(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let rows = db.query_all(self.stmt).await?; let mut models = Vec::new(); for row in rows.into_iter() { diff --git a/src/executor/update.rs b/src/executor/update.rs index da0f5c401..3051d53e7 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -16,10 +16,10 @@ impl<'a, A: 'a> UpdateOne where A: ActiveModelTrait, { - pub fn exec(self, db: &'a C) -> impl Future> + 'a - where C: ConnectionTrait { + pub async fn exec<'b: 'a, C>(self, db: &'a C) -> Result + where C: ConnectionTrait<'b> { // so that self is dropped before entering await - exec_update_and_return_original(self.query, self.model, db) + exec_update_and_return_original(self.query, self.model, db).await } } @@ -31,7 +31,7 @@ where self, db: &'a C, ) -> impl Future> + 'a - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_update_only(self.query, db) } @@ -42,40 +42,40 @@ impl Updater { Self { query } } - pub fn exec( + pub async fn exec<'a, 'b: 'a, C>( self, - db: &C, - ) -> impl Future> + '_ - where C: ConnectionTrait { + db: &'a C, + ) -> Result + where C: ConnectionTrait<'b> { let builder = db.get_database_backend(); - exec_update(builder.build(&self.query), db) + exec_update(builder.build(&self.query), db).await } } -async fn exec_update_only( +async fn exec_update_only<'a, C>( query: UpdateStatement, - db: &C, + db: &'a C, ) -> Result -where C: ConnectionTrait { +where C: ConnectionTrait<'a> { Updater::new(query).exec(db).await } -async fn exec_update_and_return_original( +async fn exec_update_and_return_original<'a, 'b: 'a, A, C>( query: UpdateStatement, model: A, - db: &C, + db: &'a C, ) -> Result where A: ActiveModelTrait, - C: ConnectionTrait, + C: ConnectionTrait<'b>, { Updater::new(query).exec(db).await?; Ok(model) } // Only Statement impl Send -async fn exec_update(statement: Statement, db: &C) -> Result -where C: ConnectionTrait { +async fn exec_update<'a, 'b: 'a, C>(statement: Statement, db: &'a C) -> Result +where C: ConnectionTrait<'b> { let result = db.execute(statement).await?; Ok(UpdateResult { rows_affected: result.rows_affected(), diff --git a/tests/transaction_tests.rs b/tests/transaction_tests.rs index 33de12a5c..cc71de1d3 100644 --- a/tests/transaction_tests.rs +++ b/tests/transaction_tests.rs @@ -14,7 +14,7 @@ pub use sea_orm::{QueryFilter, ConnectionTrait}; pub async fn transaction() { let ctx = TestContext::new("transaction_test").await; - ctx.db.transaction::<_, (), DbErr>(|txn| Box::pin(async move { + ctx.db.transaction::<_, _, DbErr>(|txn| Box::pin(async move { let _ = bakery::ActiveModel { name: Set("SeaSide Bakery".to_owned()), profit_margin: Set(10.4), @@ -60,7 +60,7 @@ pub async fn transaction_with_reference() { ctx.delete().await; } -fn _transaction_with_reference<'a>(txn: &'a DatabaseTransaction, name1: &'a str, name2: &'a str, search_name: &'a str) -> std::pin::Pin> + Send + 'a>> { +fn _transaction_with_reference<'a>(txn: &'a DatabaseTransaction<'_>, name1: &'a str, name2: &'a str, search_name: &'a str) -> std::pin::Pin> + Send + 'a>> { Box::pin(async move { let _ = bakery::ActiveModel { name: Set(name1.to_owned()), From 88214f9351dfd646c0e45f2a31b769e2e8ddb67a Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Sun, 26 Sep 2021 23:01:48 +0200 Subject: [PATCH 06/12] Less radioactive unsafe --- src/database/connection.rs | 8 ++++---- src/database/db_connection.rs | 10 +++++----- src/database/db_transaction.rs | 20 ++++++++++---------- src/driver/sqlx_mysql.rs | 6 +++--- src/driver/sqlx_postgres.rs | 6 +++--- src/driver/sqlx_sqlite.rs | 6 +++--- src/entity/active_model.rs | 2 +- tests/transaction_tests.rs | 2 +- 8 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/database/connection.rs b/src/database/connection.rs index 349694486..1a5580441 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -52,7 +52,7 @@ impl std::fmt::Debug for DatabaseConnection { } } -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl<'a> ConnectionTrait<'a> for DatabaseConnection { fn get_database_backend(&self) -> DbBackend { match self { @@ -142,11 +142,11 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&'a self, _callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + Send + 'c>> + Send + Sync, + F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + 'c>>, // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, // Fut: Future> + Send, - T: Send, - E: std::error::Error + Send, + // T: Send, + E: std::error::Error, { match self { #[cfg(feature = "sqlx-mysql")] diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 0a531f174..323f3442e 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -15,8 +15,8 @@ pub(crate) enum InnerConnection<'a> { Transaction(Box<&'a DatabaseTransaction<'a>>), } -#[async_trait::async_trait] -pub trait ConnectionTrait<'a>: Sync { +#[async_trait::async_trait(?Send)] +pub trait ConnectionTrait<'a> { fn get_database_backend(&self) -> DbBackend; async fn execute(&self, stmt: Statement) -> Result; @@ -33,11 +33,11 @@ pub trait ConnectionTrait<'a>: Sync { /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&'a self, callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + Send + 'c>> + Send + Sync, + F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + 'c>>, // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, // Fut: Future> + Send, - T: Send, - E: std::error::Error + Send; + // T: Send, + E: std::error::Error; fn is_mock_connection(&self) -> bool { false diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs index 419bb6a9e..cfd3f666b 100644 --- a/src/database/db_transaction.rs +++ b/src/database/db_transaction.rs @@ -68,11 +68,11 @@ impl<'a> DatabaseTransaction<'a> { pub(crate) async fn run(self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + 'b>>, // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, // Fut: Future> + Send, - T: Send, - E: std::error::Error + Send, + // T: Send, + E: std::error::Error, { let res = callback(&self).await.map_err(|e| TransactionError::Transaction(e)); if res.is_ok() { @@ -108,7 +108,7 @@ impl<'a> DatabaseTransaction<'a> { } // non destructive commit - fn ref_commit(&'a self) -> Pin> + Send + 'a>> { + fn ref_commit(&'a self) -> Pin> + 'a>> { Box::pin(async move { if self.conn.is_some() { match self.get_conn() { @@ -158,7 +158,7 @@ impl<'a> DatabaseTransaction<'a> { } // non destructive rollback - fn ref_rollback(&'a self) -> Pin> + Send + 'a>> { + fn ref_rollback(&'a self) -> Pin> + 'a>> { Box::pin(async move { if self.conn.is_some() { match self.get_conn() { @@ -255,9 +255,9 @@ impl<'a> Drop for DatabaseTransaction<'a> { } // this is needed since sqlite connections aren't sync -unsafe impl<'a> Sync for DatabaseTransaction<'a> {} +// unsafe impl<'a> Sync for DatabaseTransaction<'a> {} -#[async_trait::async_trait] +#[async_trait::async_trait(?Send)] impl<'a> ConnectionTrait<'a> for DatabaseTransaction<'a> { fn get_database_backend(&self) -> DbBackend { match self.conn.as_ref().map(|c| unsafe { &*c.get() }) { @@ -381,11 +381,11 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction<'a> { /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&'a self, _callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + Send + 'c>> + Send + Sync, + F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + 'c>>, // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, // Fut: Future> + Send, - T: Send, - E: std::error::Error + Send, + // T: Send, + E: std::error::Error, { let transaction = self.begin().await.map_err(|e| TransactionError::Connection(e))?; transaction.run(_callback).await diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index e6189392a..76740e024 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -115,11 +115,11 @@ impl SqlxMySqlPoolConnection { pub async fn transaction<'a, F, T, E/*, Fut*/>(&'a self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + 'b>>, // F: FnOnce(&DatabaseTransaction<'_>) -> Fut + Send, // Fut: Future> + Send, - T: Send, - E: std::error::Error + Send, + // T: Send, + E: std::error::Error, { if let Ok(conn) = self.pool.acquire().await { let transaction = DatabaseTransaction::new_mysql(conn).await.map_err(|e| TransactionError::Connection(e))?; diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index 0e4a7dda1..a1c9d5720 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -115,11 +115,11 @@ impl SqlxPostgresPoolConnection { pub async fn transaction<'a, F, T, E/*, Fut*/>(&'a self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + 'b>>, // F: FnOnce(&DatabaseTransaction<'_>) -> Fut + Send, // Fut: Future> + Send, - T: Send, - E: std::error::Error + Send, + // T: Send, + E: std::error::Error, { if let Ok(conn) = self.pool.acquire().await { let transaction = DatabaseTransaction::new_postgres(conn).await.map_err(|e| TransactionError::Connection(e))?; diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index 9b168a6d9..e8d6f47c8 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -115,11 +115,11 @@ impl SqlxSqlitePoolConnection { pub async fn transaction<'a, F, T, E/*, Fut*/>(&'a self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + 'b>>, // F: FnOnce(&DatabaseTransaction<'_>) -> Fut + Send, // Fut: Future> + Send, - T: Send, - E: std::error::Error + Send, + // T: Send, + E: std::error::Error, { if let Ok(conn) = self.pool.acquire().await { let transaction = DatabaseTransaction::new_sqlite(conn).await.map_err(|e| TransactionError::Connection(e))?; diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index 752e6f088..aa2b0737d 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -51,7 +51,7 @@ where ActiveValue::unchanged(value) } -#[async_trait] +#[async_trait(?Send)] pub trait ActiveModelTrait: Clone + Debug { type Entity: EntityTrait; diff --git a/tests/transaction_tests.rs b/tests/transaction_tests.rs index cc71de1d3..e0998940c 100644 --- a/tests/transaction_tests.rs +++ b/tests/transaction_tests.rs @@ -60,7 +60,7 @@ pub async fn transaction_with_reference() { ctx.delete().await; } -fn _transaction_with_reference<'a>(txn: &'a DatabaseTransaction<'_>, name1: &'a str, name2: &'a str, search_name: &'a str) -> std::pin::Pin> + Send + 'a>> { +fn _transaction_with_reference<'a>(txn: &'a DatabaseTransaction<'_>, name1: &'a str, name2: &'a str, search_name: &'a str) -> std::pin::Pin> + 'a>> { Box::pin(async move { let _ = bakery::ActiveModel { name: Set(name1.to_owned()), From 089ebd751457fac77635d288d93c000a1de06b17 Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Wed, 29 Sep 2021 10:48:43 +0200 Subject: [PATCH 07/12] Mutex transaction --- Cargo.toml | 2 +- src/database/connection.rs | 52 ++-- src/database/db_connection.rs | 28 +-- src/database/db_transaction.rs | 263 ++++++-------------- src/database/mock.rs | 4 +- src/database/stream/mod.rs | 5 + src/database/{stream.rs => stream/query.rs} | 35 +-- src/database/stream/transaction.rs | 79 ++++++ src/driver/mock.rs | 4 +- src/driver/sqlx_mysql.rs | 14 +- src/driver/sqlx_postgres.rs | 14 +- src/driver/sqlx_sqlite.rs | 14 +- src/entity/active_model.rs | 16 +- src/executor/insert.rs | 12 +- src/executor/update.rs | 14 +- tests/stream_tests.rs | 2 +- tests/transaction_tests.rs | 2 +- 17 files changed, 261 insertions(+), 299 deletions(-) create mode 100644 src/database/stream/mod.rs rename src/database/{stream.rs => stream/query.rs} (71%) create mode 100644 src/database/stream/transaction.rs diff --git a/Cargo.toml b/Cargo.toml index abe43062a..eeba39c9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } -ouroboros = "0.10" +ouroboros = "0.11" [dev-dependencies] smol = { version = "^1.2" } diff --git a/src/database/connection.rs b/src/database/connection.rs index 1a5580441..b1ba1e7cb 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,5 +1,5 @@ -use std::{future::Future, pin::Pin}; -use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, QueryStream, Statement, StatementBuilder, TransactionError, error::*}; +use std::{future::Future, pin::Pin, sync::Arc}; +use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, error::*}; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; #[cfg_attr(not(feature = "mock"), derive(Clone))] @@ -11,7 +11,7 @@ pub enum DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), #[cfg(feature = "mock")] - MockDatabaseConnection(crate::MockDatabaseConnection), + MockDatabaseConnection(Arc), Disconnected, } @@ -52,8 +52,10 @@ impl std::fmt::Debug for DatabaseConnection { } } -#[async_trait::async_trait(?Send)] +#[async_trait::async_trait] impl<'a> ConnectionTrait<'a> for DatabaseConnection { + type Stream = crate::QueryStream; + fn get_database_backend(&self) -> DbBackend { match self { #[cfg(feature = "sqlx-mysql")] @@ -110,21 +112,23 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { } } - async fn stream(&'a self, stmt: Statement) -> Result, DbErr> { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await, - #[cfg(feature = "sqlx-postgres")] - DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await, - #[cfg(feature = "sqlx-sqlite")] - DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await, - #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => Ok(QueryStream::from((conn, stmt))), - DatabaseConnection::Disconnected => panic!("Disconnected"), - } + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>> { + Box::pin(async move { + Ok(match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => crate::QueryStream::from((Arc::clone(conn), stmt)), + DatabaseConnection::Disconnected => panic!("Disconnected"), + }) + }) } - async fn begin(&'a self) -> Result, DbErr> { + async fn begin(&self) -> Result { match self { #[cfg(feature = "sqlx-mysql")] DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, @@ -133,20 +137,18 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => DatabaseTransaction::new_mock(conn).await, + DatabaseConnection::MockDatabaseConnection(conn) => DatabaseTransaction::new_mock(Arc::clone(conn)).await, DatabaseConnection::Disconnected => panic!("Disconnected"), } } /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&'a self, _callback: F) -> Result> + async fn transaction(&self, _callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + 'c>>, - // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, - // Fut: Future> + Send, - // T: Send, - E: std::error::Error, + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + T: Send, + E: std::error::Error + Send, { match self { #[cfg(feature = "sqlx-mysql")] @@ -157,7 +159,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, #[cfg(feature = "mock")] DatabaseConnection::MockDatabaseConnection(conn) => { - let transaction = DatabaseTransaction::new_mock(conn).await.map_err(|e| TransactionError::Connection(e))?; + let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)).await.map_err(|e| TransactionError::Connection(e))?; transaction.run(_callback).await }, DatabaseConnection::Disconnected => panic!("Disconnected"), diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 323f3442e..391a40d5d 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -1,9 +1,10 @@ -use std::{future::Future, pin::Pin}; -use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, MockDatabaseConnection, QueryResult, QueryStream, Statement, TransactionError}; +use std::{future::Future, pin::Pin, sync::Arc}; +use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, MockDatabaseConnection, QueryResult, Statement, TransactionError}; +use futures::Stream; #[cfg(feature = "sqlx-dep")] use sqlx::pool::PoolConnection; -pub(crate) enum InnerConnection<'a> { +pub(crate) enum InnerConnection { #[cfg(feature = "sqlx-mysql")] MySql(PoolConnection), #[cfg(feature = "sqlx-postgres")] @@ -11,12 +12,13 @@ pub(crate) enum InnerConnection<'a> { #[cfg(feature = "sqlx-sqlite")] Sqlite(PoolConnection), #[cfg(feature = "mock")] - Mock(&'a MockDatabaseConnection), - Transaction(Box<&'a DatabaseTransaction<'a>>), + Mock(Arc), } -#[async_trait::async_trait(?Send)] +#[async_trait::async_trait] pub trait ConnectionTrait<'a> { + type Stream: Stream>; + fn get_database_backend(&self) -> DbBackend; async fn execute(&self, stmt: Statement) -> Result; @@ -25,19 +27,17 @@ pub trait ConnectionTrait<'a> { async fn query_all(&self, stmt: Statement) -> Result, DbErr>; - async fn stream(&'a self, stmt: Statement) -> Result, DbErr>; + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>>; - async fn begin(&'a self) -> Result, DbErr>; + async fn begin(&self) -> Result; /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&'a self, callback: F) -> Result> + async fn transaction(&self, callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + 'c>>, - // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, - // Fut: Future> + Send, - // T: Send, - E: std::error::Error; + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + T: Send, + E: std::error::Error + Send; fn is_mock_connection(&self) -> bool { false diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs index cfd3f666b..fba8b857a 100644 --- a/src/database/db_transaction.rs +++ b/src/database/db_transaction.rs @@ -1,78 +1,77 @@ -use std::{cell::UnsafeCell, future::Future, pin::Pin}; -use crate::{ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, QueryStream, Statement, debug_print}; -use futures::{Stream, TryStreamExt}; +use std::{sync::Arc, future::Future, pin::Pin}; +use crate::{ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, Statement, TransactionStream, debug_print}; +use futures::lock::Mutex; #[cfg(feature = "sqlx-dep")] use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; #[cfg(feature = "sqlx-dep")] -use sqlx::{pool::PoolConnection, Executor, TransactionManager}; +use sqlx::{pool::PoolConnection, TransactionManager}; // a Transaction is just a sugar for a connection where START TRANSACTION has been executed -pub struct DatabaseTransaction<'a> { - // using Option we don't even need an "open" flag - conn: Option>>, +pub struct DatabaseTransaction { + conn: Arc>, + backend: DbBackend, + open: bool, } -impl<'a> std::fmt::Debug for DatabaseTransaction<'a> { +impl std::fmt::Debug for DatabaseTransaction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "DatabaseTransaction") } } -impl<'a> DatabaseTransaction<'a> { +impl DatabaseTransaction { #[cfg(feature = "sqlx-mysql")] - pub(crate) async fn new_mysql(inner: PoolConnection) -> Result, DbErr> { - Self::build(InnerConnection::MySql(inner)).await + pub(crate) async fn new_mysql(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::MySql(inner))), DbBackend::MySql).await } #[cfg(feature = "sqlx-postgres")] - pub(crate) async fn new_postgres(inner: PoolConnection) -> Result, DbErr> { - Self::build(InnerConnection::Postgres(inner)).await + pub(crate) async fn new_postgres(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::Postgres(inner))), DbBackend::Postgres).await } #[cfg(feature = "sqlx-sqlite")] - pub(crate) async fn new_sqlite(inner: PoolConnection) -> Result, DbErr> { - Self::build(InnerConnection::Sqlite(inner)).await + pub(crate) async fn new_sqlite(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), DbBackend::Sqlite).await } #[cfg(feature = "mock")] - pub(crate) async fn new_mock(inner: &'a crate::MockDatabaseConnection) -> Result, DbErr> { - Self::build(InnerConnection::Mock(inner)).await + pub(crate) async fn new_mock(inner: Arc) -> Result { + let backend = inner.get_database_backend(); + Self::build(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await } - async fn build(conn: InnerConnection<'a>) -> Result, DbErr> { - let mut res = DatabaseTransaction { - conn: Some(UnsafeCell::new(conn)), + async fn build(conn: Arc>, backend: DbBackend) -> Result { + let res = DatabaseTransaction { + conn, + backend, + open: true, }; - match res.conn.as_mut().map(|c| c.get_mut()) { + match *res.conn.lock().await { #[cfg(feature = "sqlx-mysql")] - Some(InnerConnection::MySql(c)) => { + InnerConnection::MySql(ref mut c) => { ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-postgres")] - Some(InnerConnection::Postgres(c)) => { + InnerConnection::Postgres(ref mut c) => { ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-sqlite")] - Some(InnerConnection::Sqlite(c)) => { + InnerConnection::Sqlite(ref mut c) => { ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? }, // should we do something for mocked connections? #[cfg(feature = "mock")] - Some(InnerConnection::Mock(_)) => {}, - // nested transactions should already have been started - Some(InnerConnection::Transaction(_)) => {}, - _ => unreachable!(), + InnerConnection::Mock(_) => {}, } Ok(res) } - pub(crate) async fn run(self, callback: F) -> Result> + pub(crate) async fn run(self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + 'b>>, - // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, - // Fut: Future> + Send, - // T: Send, - E: std::error::Error, + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, { let res = callback(&self).await.map_err(|e| TransactionError::Transaction(e)); if res.is_ok() { @@ -85,199 +84,94 @@ impl<'a> DatabaseTransaction<'a> { } pub async fn commit(mut self) -> Result<(), DbErr> { - match self.conn.take().map(|c| c.into_inner()) { + self.open = false; + match *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] - Some(InnerConnection::MySql(ref mut c)) => { + InnerConnection::MySql(ref mut c) => { ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-postgres")] - Some(InnerConnection::Postgres(ref mut c)) => { + InnerConnection::Postgres(ref mut c) => { ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-sqlite")] - Some(InnerConnection::Sqlite(ref mut c)) => { + InnerConnection::Sqlite(ref mut c) => { ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? }, - Some(InnerConnection::Transaction(c)) => c.ref_commit().await?, - //Should we do something for mocked &connections? + //Should we do something for mocked connections? #[cfg(feature = "mock")] - Some(InnerConnection::Mock(_)) => {}, - _ => unreachable!(), + InnerConnection::Mock(_) => {}, } Ok(()) } - // non destructive commit - fn ref_commit(&'a self) -> Pin> + 'a>> { - Box::pin(async move { - if self.conn.is_some() { - match self.get_conn() { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(ref mut c) => { - ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? - }, - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? - }, - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? - }, - InnerConnection::Transaction(c) => c.ref_commit().await?, - //Should we do something for mocked &connections? - #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {}, - } - } - Ok(()) - }) - } - pub async fn rollback(mut self) -> Result<(), DbErr> { - match self.conn.take().map(|c| c.into_inner()) { + self.open = false; + match *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] - Some(InnerConnection::MySql(ref mut c)) => { + InnerConnection::MySql(ref mut c) => { ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-postgres")] - Some(InnerConnection::Postgres(ref mut c)) => { + InnerConnection::Postgres(ref mut c) => { ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-sqlite")] - Some(InnerConnection::Sqlite(ref mut c)) => { + InnerConnection::Sqlite(ref mut c) => { ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? }, - Some(InnerConnection::Transaction(c)) => c.ref_rollback().await?, - //Should we do something for mocked &connections? + //Should we do something for mocked connections? #[cfg(feature = "mock")] - Some(InnerConnection::Mock(_)) => {}, - _ => unreachable!(), + InnerConnection::Mock(_) => {}, } Ok(()) } - // non destructive rollback - fn ref_rollback(&'a self) -> Pin> + 'a>> { - Box::pin(async move { - if self.conn.is_some() { - match self.get_conn() { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(ref mut c) => { - ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? - }, - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(ref mut c) => { - ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? - }, - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(ref mut c) => { - ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? - }, - InnerConnection::Transaction(c) => c.ref_rollback().await?, - //Should we do something for mocked &connections? - #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {}, - } - } - Ok(()) - }) - } - - pub(crate) fn fetch<'b>(&'b self, stmt: &'b Statement) -> Pin> + 'b>> { - match self.get_conn() { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(inner) => { - let query = crate::driver::sqlx_mysql::sqlx_query(stmt); - Box::pin(inner.fetch(query) - .map_ok(Into::into) - .map_err(sqlx_error_to_query_err)) - }, - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(inner) => { - let query = crate::driver::sqlx_postgres::sqlx_query(stmt); - Box::pin(inner.fetch(query) - .map_ok(Into::into) - .map_err(sqlx_error_to_query_err)) - }, - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(inner) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); - Box::pin(inner.fetch(query) - .map_ok(Into::into) - .map_err(sqlx_error_to_query_err)) - }, - InnerConnection::Transaction(inner) => { - inner.fetch(stmt) - }, - #[cfg(feature = "mock")] - InnerConnection::Mock(inner) => { - inner.fetch(stmt) - }, - } - } - // the rollback is queued and will be performed on next async operation, like returning the connection to the pool - fn start_rollback(&self) { - if let Some(conn) = self.conn.as_ref() { - match unsafe { &mut *conn.get() } { + fn start_rollback(&mut self) { + if self.open { + match Arc::get_mut(&mut self.conn).map(|o| o.get_mut()) { #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(c) => { + Some(InnerConnection::MySql(c)) => { ::TransactionManager::start_rollback(c); }, #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(c) => { + Some(InnerConnection::Postgres(c)) => { ::TransactionManager::start_rollback(c); }, #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(c) => { + Some(InnerConnection::Sqlite(c)) => { ::TransactionManager::start_rollback(c); }, - InnerConnection::Transaction(c) => { - c.start_rollback(); - } - //Should we do something for mocked &connections? + //Should we do something for mocked connections? #[cfg(feature = "mock")] - InnerConnection::Mock(_) => {}, + Some(InnerConnection::Mock(_)) => {}, + //this happens if this is a nested transaction + None => unreachable!(), } } } - - fn get_conn(&self) -> &mut InnerConnection<'a> { - unsafe { &mut *self.conn.as_ref().map(|c| c.get()).unwrap() } - } } -impl<'a> Drop for DatabaseTransaction<'a> { +impl Drop for DatabaseTransaction { fn drop(&mut self) { self.start_rollback(); } } -// this is needed since sqlite connections aren't sync -// unsafe impl<'a> Sync for DatabaseTransaction<'a> {} +#[async_trait::async_trait] +impl<'a> ConnectionTrait<'a> for DatabaseTransaction { + type Stream = TransactionStream<'a>; -#[async_trait::async_trait(?Send)] -impl<'a> ConnectionTrait<'a> for DatabaseTransaction<'a> { fn get_database_backend(&self) -> DbBackend { - match self.conn.as_ref().map(|c| unsafe { &*c.get() }) { - #[cfg(feature = "sqlx-mysql")] - Some(InnerConnection::MySql(_)) => DbBackend::MySql, - #[cfg(feature = "sqlx-postgres")] - Some(InnerConnection::Postgres(_)) => DbBackend::Postgres, - #[cfg(feature = "sqlx-sqlite")] - Some(InnerConnection::Sqlite(_)) => DbBackend::Sqlite, - #[cfg(feature = "mock")] - Some(InnerConnection::Mock(c)) => c.get_database_backend(), - Some(InnerConnection::Transaction(c)) => c.get_database_backend(), - _ => unreachable!(), - } + // this way we don't need to lock + self.backend } async fn execute(&self, stmt: Statement) -> Result { debug_print!("{}", stmt); - let _res = match self.get_conn() { + let _res = match &mut *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); @@ -298,7 +192,6 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction<'a> { }, #[cfg(feature = "mock")] InnerConnection::Mock(conn) => return conn.execute(stmt), - InnerConnection::Transaction(conn) => return conn.execute(stmt).await, }; #[cfg(feature = "sqlx-dep")] _res.map_err(sqlx_error_to_exec_err) @@ -307,7 +200,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction<'a> { async fn query_one(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); - let _res = match self.get_conn() { + let _res = match &mut *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); @@ -322,13 +215,12 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction<'a> { }, #[cfg(feature = "sqlx-sqlite")] InnerConnection::Sqlite(conn) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); + let query= crate::driver::sqlx_sqlite::sqlx_query(&stmt); query.fetch_one(conn).await .map(|row| Some(row.into())) }, #[cfg(feature = "mock")] InnerConnection::Mock(conn) => return conn.query_one(stmt), - InnerConnection::Transaction(conn) => return conn.query_one(stmt).await, }; #[cfg(feature = "sqlx-dep")] if let Err(sqlx::Error::RowNotFound) = _res { @@ -342,7 +234,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction<'a> { async fn query_all(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); - let _res = match self.get_conn() { + let _res = match &mut *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); @@ -363,29 +255,28 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction<'a> { }, #[cfg(feature = "mock")] InnerConnection::Mock(conn) => return conn.query_all(stmt), - InnerConnection::Transaction(conn) => return conn.query_all(stmt).await, }; #[cfg(feature = "sqlx-dep")] _res.map_err(sqlx_error_to_query_err) } - async fn stream(&'a self, stmt: Statement) -> Result, DbErr> { - Ok(QueryStream::from((self, stmt))) + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>> { + Box::pin(async move { + Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) + }) } - async fn begin(&'a self) -> Result, DbErr> { - DatabaseTransaction::build(InnerConnection::Transaction(Box::new(self))).await + async fn begin(&self) -> Result { + DatabaseTransaction::build(Arc::clone(&self.conn), self.backend).await } /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. - async fn transaction(&'a self, _callback: F) -> Result> + async fn transaction(&self, _callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'a>) -> Pin> + 'c>>, - // F: FnOnce(&DatabaseTransaction<'a>) -> Fut + Send, - // Fut: Future> + Send, - // T: Send, - E: std::error::Error, + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, + T: Send, + E: std::error::Error + Send, { let transaction = self.begin().await.map_err(|e| TransactionError::Connection(e))?; transaction.run(_callback).await diff --git a/src/database/mock.rs b/src/database/mock.rs index 7077a6330..d0862e807 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -4,7 +4,7 @@ use crate::{ Statement, Transaction, }; use sea_query::{Value, ValueType}; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; #[derive(Debug)] pub struct MockDatabase { @@ -53,7 +53,7 @@ impl MockDatabase { } pub fn into_connection(self) -> DatabaseConnection { - DatabaseConnection::MockDatabaseConnection(MockDatabaseConnection::new(self)) + DatabaseConnection::MockDatabaseConnection(Arc::new(MockDatabaseConnection::new(self))) } pub fn append_exec_results(mut self, mut vec: Vec) -> Self { diff --git a/src/database/stream/mod.rs b/src/database/stream/mod.rs new file mode 100644 index 000000000..774cf45fa --- /dev/null +++ b/src/database/stream/mod.rs @@ -0,0 +1,5 @@ +mod query; +mod transaction; + +pub use query::*; +pub use transaction::*; diff --git a/src/database/stream.rs b/src/database/stream/query.rs similarity index 71% rename from src/database/stream.rs rename to src/database/stream/query.rs index fdc83ca17..0bc4dd6c8 100644 --- a/src/database/stream.rs +++ b/src/database/stream/query.rs @@ -1,62 +1,56 @@ -use std::{pin::Pin, task::Poll}; +use std::{pin::Pin, task::Poll, sync::Arc}; use futures::{Stream, TryStreamExt}; use sqlx::{pool::PoolConnection, Executor}; -use crate::{DatabaseTransaction, DbErr, InnerConnection, QueryResult, Statement, sqlx_error_to_query_err}; +use crate::{DbErr, InnerConnection, QueryResult, Statement, sqlx_error_to_query_err}; #[ouroboros::self_referencing] -pub struct QueryStream<'a> { +pub struct QueryStream { stmt: Statement, - conn: InnerConnection<'a>, + conn: InnerConnection, #[borrows(mut conn, stmt)] #[not_covariant] stream: Pin> + 'this>>, } #[cfg(feature = "sqlx-mysql")] -impl<'a> From<(PoolConnection, Statement)> for QueryStream<'a> { +impl From<(PoolConnection, Statement)> for QueryStream { fn from((conn, stmt): (PoolConnection, Statement)) -> Self { QueryStream::build(stmt, InnerConnection::MySql(conn)) } } #[cfg(feature = "sqlx-postgres")] -impl<'a> From<(PoolConnection, Statement)> for QueryStream<'a> { +impl From<(PoolConnection, Statement)> for QueryStream { fn from((conn, stmt): (PoolConnection, Statement)) -> Self { QueryStream::build(stmt, InnerConnection::Postgres(conn)) } } #[cfg(feature = "sqlx-sqlite")] -impl<'a> From<(PoolConnection, Statement)> for QueryStream<'a> { +impl From<(PoolConnection, Statement)> for QueryStream { fn from((conn, stmt): (PoolConnection, Statement)) -> Self { QueryStream::build(stmt, InnerConnection::Sqlite(conn)) } } #[cfg(feature = "mock")] -impl<'a> From<(&'a crate::MockDatabaseConnection, Statement)> for QueryStream<'a> { - fn from((conn, stmt): (&'a crate::MockDatabaseConnection, Statement)) -> Self { +impl From<(Arc, Statement)> for QueryStream { + fn from((conn, stmt): (Arc, Statement)) -> Self { QueryStream::build(stmt, InnerConnection::Mock(conn)) } } -impl<'a> From<(&'a DatabaseTransaction<'a>, Statement)> for QueryStream<'a> { - fn from((conn, stmt): (&'a DatabaseTransaction<'a>, Statement)) -> Self { - QueryStream::build(stmt, InnerConnection::Transaction(Box::new(conn))) - } -} - -impl<'a> std::fmt::Debug for QueryStream<'a> { +impl std::fmt::Debug for QueryStream { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "QueryStream") } } -impl<'a> QueryStream<'a> { - fn build(stmt: Statement, conn: InnerConnection<'a>) -> QueryStream<'a> { +impl QueryStream { + fn build(stmt: Statement, conn: InnerConnection) -> QueryStream { QueryStreamBuilder { stmt, conn, @@ -93,16 +87,13 @@ impl<'a> QueryStream<'a> { InnerConnection::Mock(c) => { c.fetch(stmt) }, - InnerConnection::Transaction(c) => { - c.fetch(stmt) - }, } }, }.build() } } -impl<'a> Stream for QueryStream<'a> { +impl Stream for QueryStream { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs new file mode 100644 index 000000000..30ec810ba --- /dev/null +++ b/src/database/stream/transaction.rs @@ -0,0 +1,79 @@ +use std::{ops::DerefMut, pin::Pin, task::Poll}; + +use futures::{Stream, TryStreamExt}; + +use sqlx::Executor; + +use futures::lock::MutexGuard; + +use crate::{DbErr, InnerConnection, QueryResult, Statement, sqlx_error_to_query_err}; + +#[ouroboros::self_referencing] +pub struct TransactionStream<'a> { + stmt: Statement, + conn: MutexGuard<'a, InnerConnection>, + #[borrows(mut conn, stmt)] + #[not_covariant] + stream: Pin> + 'this>>, +} + +impl<'a> std::fmt::Debug for TransactionStream<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TransactionStream") + } +} + +impl<'a> TransactionStream<'a> { + pub(crate) async fn build(conn: MutexGuard<'a, InnerConnection>, stmt: Statement) -> TransactionStream<'a> { + TransactionStreamAsyncBuilder { + stmt, + conn, + stream_builder: |conn, stmt| Box::pin(async move { + match conn.deref_mut() { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => { + c.fetch(stmt) + }, + } + }), + }.build().await + } +} + +impl<'a> Stream for TransactionStream<'a> { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + this.with_stream_mut(|stream| { + stream.as_mut().poll_next(cx) + }) + } +} diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 4aeb024e6..5e4cc84ef 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -2,7 +2,7 @@ use crate::{ debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, Statement, Transaction, }; -use std::{fmt::Debug, pin::Pin, sync::{ +use std::{fmt::Debug, pin::Pin, sync::{Arc, atomic::{AtomicUsize, Ordering}, Mutex, }}; @@ -50,7 +50,7 @@ impl MockDatabaseConnector { macro_rules! connect_mock_db { ( $syntax: expr ) => { Ok(DatabaseConnection::MockDatabaseConnection( - MockDatabaseConnection::new(MockDatabase::new($syntax)), + Arc::new(MockDatabaseConnection::new(MockDatabase::new($syntax))), )) }; } diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 76740e024..75e6e5ff1 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -91,7 +91,7 @@ impl SqlxMySqlPoolConnection { } } - pub async fn stream(&self, stmt: Statement) -> Result, DbErr> { + pub async fn stream(&self, stmt: Statement) -> Result { debug_print!("{}", stmt); if let Ok(conn) = self.pool.acquire().await { @@ -103,7 +103,7 @@ impl SqlxMySqlPoolConnection { } } - pub async fn begin(&self) -> Result, DbErr> { + pub async fn begin(&self) -> Result { if let Ok(conn) = self.pool.acquire().await { DatabaseTransaction::new_mysql(conn).await } else { @@ -113,13 +113,11 @@ impl SqlxMySqlPoolConnection { } } - pub async fn transaction<'a, F, T, E/*, Fut*/>(&'a self, callback: F) -> Result> + pub async fn transaction(&self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + 'b>>, - // F: FnOnce(&DatabaseTransaction<'_>) -> Fut + Send, - // Fut: Future> + Send, - // T: Send, - E: std::error::Error, + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { let transaction = DatabaseTransaction::new_mysql(conn).await.map_err(|e| TransactionError::Connection(e))?; diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index a1c9d5720..c9949375b 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -91,7 +91,7 @@ impl SqlxPostgresPoolConnection { } } - pub async fn stream(&self, stmt: Statement) -> Result, DbErr> { + pub async fn stream(&self, stmt: Statement) -> Result { debug_print!("{}", stmt); if let Ok(conn) = self.pool.acquire().await { @@ -103,7 +103,7 @@ impl SqlxPostgresPoolConnection { } } - pub async fn begin(&self) -> Result, DbErr> { + pub async fn begin(&self) -> Result { if let Ok(conn) = self.pool.acquire().await { DatabaseTransaction::new_postgres(conn).await } else { @@ -113,13 +113,11 @@ impl SqlxPostgresPoolConnection { } } - pub async fn transaction<'a, F, T, E/*, Fut*/>(&'a self, callback: F) -> Result> + pub async fn transaction(&self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + 'b>>, - // F: FnOnce(&DatabaseTransaction<'_>) -> Fut + Send, - // Fut: Future> + Send, - // T: Send, - E: std::error::Error, + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { let transaction = DatabaseTransaction::new_postgres(conn).await.map_err(|e| TransactionError::Connection(e))?; diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index e8d6f47c8..bf06a2659 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -91,7 +91,7 @@ impl SqlxSqlitePoolConnection { } } - pub async fn stream(&self, stmt: Statement) -> Result, DbErr> { + pub async fn stream(&self, stmt: Statement) -> Result { debug_print!("{}", stmt); if let Ok(conn) = self.pool.acquire().await { @@ -103,7 +103,7 @@ impl SqlxSqlitePoolConnection { } } - pub async fn begin(&self) -> Result, DbErr> { + pub async fn begin(&self) -> Result { if let Ok(conn) = self.pool.acquire().await { DatabaseTransaction::new_sqlite(conn).await } else { @@ -113,13 +113,11 @@ impl SqlxSqlitePoolConnection { } } - pub async fn transaction<'a, F, T, E/*, Fut*/>(&'a self, callback: F) -> Result> + pub async fn transaction(&self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + 'b>>, - // F: FnOnce(&DatabaseTransaction<'_>) -> Fut + Send, - // Fut: Future> + Send, - // T: Send, - E: std::error::Error, + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, + T: Send, + E: std::error::Error + Send, { if let Ok(conn) = self.pool.acquire().await { let transaction = DatabaseTransaction::new_sqlite(conn).await.map_err(|e| TransactionError::Connection(e))?; diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index aa2b0737d..6979b9e32 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -51,7 +51,7 @@ where ActiveValue::unchanged(value) } -#[async_trait(?Send)] +#[async_trait] pub trait ActiveModelTrait: Clone + Debug { type Entity: EntityTrait; @@ -67,10 +67,10 @@ pub trait ActiveModelTrait: Clone + Debug { fn default() -> Self; - async fn insert<'a, 'b: 'a, C>(self, db: &'a C) -> Result + async fn insert<'a, C>(self, db: &'a C) -> Result where ::Model: IntoActiveModel, - C: ConnectionTrait<'b>, + C: ConnectionTrait<'a> + Sync, Self: 'a, { let am = self; @@ -92,9 +92,9 @@ pub trait ActiveModelTrait: Clone + Debug { } } - async fn update<'a, 'b: 'a, C>(self, db: &'a C) -> Result + async fn update<'a, C>(self, db: &'a C) -> Result where - C: ConnectionTrait<'b>, + C: ConnectionTrait<'a> + Sync, Self: 'a, { let exec = Self::Entity::update(self).exec(db); @@ -103,11 +103,11 @@ pub trait ActiveModelTrait: Clone + Debug { /// Insert the model if primary key is unset, update otherwise. /// Only works if the entity has auto increment primary key. - async fn save<'a, 'b: 'a, C>(self, db: &'a C) -> Result + async fn save<'a, C>(self, db: &'a C) -> Result where Self: ActiveModelBehavior + 'a, ::Model: IntoActiveModel, - C: ConnectionTrait<'b>, + C: ConnectionTrait<'a> + Sync, { let mut am = self; am = ActiveModelBehavior::before_save(am); @@ -132,7 +132,7 @@ pub trait ActiveModelTrait: Clone + Debug { async fn delete<'a, C>(self, db: &'a C) -> Result where Self: ActiveModelBehavior + 'a, - C: ConnectionTrait<'a>, + C: ConnectionTrait<'a> + Sync, { let mut am = self; am = ActiveModelBehavior::before_delete(am); diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 4f699ac84..3f5ced1d4 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -24,12 +24,12 @@ where A: ActiveModelTrait, { #[allow(unused_mut)] - pub async fn exec<'a, 'b: 'a, C>( + pub async fn exec<'a, C>( self, db: &'a C, ) -> Result, DbErr> where - C: ConnectionTrait<'b>, + C: ConnectionTrait<'a>, A: 'a, { // TODO: extract primary key's value from query @@ -62,12 +62,12 @@ where } } - pub async fn exec<'a, 'b: 'a, C>( + pub async fn exec<'a, C>( self, db: &'a C, ) -> Result, DbErr> where - C: ConnectionTrait<'b>, + C: ConnectionTrait<'a>, A: 'a, { let builder = db.get_database_backend(); @@ -76,12 +76,12 @@ where } // Only Statement impl Send -async fn exec_insert<'a, 'b: 'a, A, C>( +async fn exec_insert<'a, A, C>( statement: Statement, db: &C, ) -> Result, DbErr> where - C: ConnectionTrait<'b>, + C: ConnectionTrait<'a>, A: ActiveModelTrait, { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; diff --git a/src/executor/update.rs b/src/executor/update.rs index 3051d53e7..06cd514eb 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -16,7 +16,7 @@ impl<'a, A: 'a> UpdateOne where A: ActiveModelTrait, { - pub async fn exec<'b: 'a, C>(self, db: &'a C) -> Result + pub async fn exec<'b, C>(self, db: &'b C) -> Result where C: ConnectionTrait<'b> { // so that self is dropped before entering await exec_update_and_return_original(self.query, self.model, db).await @@ -42,11 +42,11 @@ impl Updater { Self { query } } - pub async fn exec<'a, 'b: 'a, C>( + pub async fn exec<'a, C>( self, db: &'a C, ) -> Result - where C: ConnectionTrait<'b> { + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); exec_update(builder.build(&self.query), db).await } @@ -60,22 +60,22 @@ where C: ConnectionTrait<'a> { Updater::new(query).exec(db).await } -async fn exec_update_and_return_original<'a, 'b: 'a, A, C>( +async fn exec_update_and_return_original<'a, A, C>( query: UpdateStatement, model: A, db: &'a C, ) -> Result where A: ActiveModelTrait, - C: ConnectionTrait<'b>, + C: ConnectionTrait<'a>, { Updater::new(query).exec(db).await?; Ok(model) } // Only Statement impl Send -async fn exec_update<'a, 'b: 'a, C>(statement: Statement, db: &'a C) -> Result -where C: ConnectionTrait<'b> { +async fn exec_update<'a, C>(statement: Statement, db: &'a C) -> Result +where C: ConnectionTrait<'a> { let result = db.execute(statement).await?; Ok(UpdateResult { rows_affected: result.rows_affected(), diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs index a6c22dc89..969b93e1d 100644 --- a/tests/stream_tests.rs +++ b/tests/stream_tests.rs @@ -23,7 +23,7 @@ pub async fn stream() -> Result<(), DbErr> { .await?; let result = Bakery::find_by_id(bakery.id.clone().unwrap()) - .stream(&ctx.db) + .stream(&ctx.db) .await? .next() .await diff --git a/tests/transaction_tests.rs b/tests/transaction_tests.rs index e0998940c..539eaefcc 100644 --- a/tests/transaction_tests.rs +++ b/tests/transaction_tests.rs @@ -60,7 +60,7 @@ pub async fn transaction_with_reference() { ctx.delete().await; } -fn _transaction_with_reference<'a>(txn: &'a DatabaseTransaction<'_>, name1: &'a str, name2: &'a str, search_name: &'a str) -> std::pin::Pin> + 'a>> { +fn _transaction_with_reference<'a>(txn: &'a DatabaseTransaction, name1: &'a str, name2: &'a str, search_name: &'a str) -> std::pin::Pin> + Send + 'a>> { Box::pin(async move { let _ = bakery::ActiveModel { name: Set(name1.to_owned()), From 105d48d21dd1e95b2d4f685ee9d90299f1209f89 Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Wed, 29 Sep 2021 10:55:46 +0200 Subject: [PATCH 08/12] Solve clippy lints --- src/database/stream/query.rs | 13 ++++++++----- src/database/stream/transaction.rs | 13 ++++++++----- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs index 0bc4dd6c8..553d9f7b0 100644 --- a/src/database/stream/query.rs +++ b/src/database/stream/query.rs @@ -1,10 +1,13 @@ use std::{pin::Pin, task::Poll, sync::Arc}; -use futures::{Stream, TryStreamExt}; +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use futures::TryStreamExt; +#[cfg(feature = "sqlx-dep")] use sqlx::{pool::PoolConnection, Executor}; -use crate::{DbErr, InnerConnection, QueryResult, Statement, sqlx_error_to_query_err}; +use crate::{DbErr, InnerConnection, QueryResult, Statement}; #[ouroboros::self_referencing] pub struct QueryStream { @@ -62,7 +65,7 @@ impl QueryStream { Box::pin( c.fetch(query) .map_ok(Into::into) - .map_err(sqlx_error_to_query_err) + .map_err(crate::sqlx_error_to_query_err) ) }, #[cfg(feature = "sqlx-postgres")] @@ -71,7 +74,7 @@ impl QueryStream { Box::pin( c.fetch(query) .map_ok(Into::into) - .map_err(sqlx_error_to_query_err) + .map_err(crate::sqlx_error_to_query_err) ) }, #[cfg(feature = "sqlx-sqlite")] @@ -80,7 +83,7 @@ impl QueryStream { Box::pin( c.fetch(query) .map_ok(Into::into) - .map_err(sqlx_error_to_query_err) + .map_err(crate::sqlx_error_to_query_err) ) }, #[cfg(feature = "mock")] diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs index 30ec810ba..d945f4095 100644 --- a/src/database/stream/transaction.rs +++ b/src/database/stream/transaction.rs @@ -1,12 +1,15 @@ use std::{ops::DerefMut, pin::Pin, task::Poll}; -use futures::{Stream, TryStreamExt}; +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use futures::TryStreamExt; +#[cfg(feature = "sqlx-dep")] use sqlx::Executor; use futures::lock::MutexGuard; -use crate::{DbErr, InnerConnection, QueryResult, Statement, sqlx_error_to_query_err}; +use crate::{DbErr, InnerConnection, QueryResult, Statement}; #[ouroboros::self_referencing] pub struct TransactionStream<'a> { @@ -36,7 +39,7 @@ impl<'a> TransactionStream<'a> { Box::pin( c.fetch(query) .map_ok(Into::into) - .map_err(sqlx_error_to_query_err) + .map_err(crate::sqlx_error_to_query_err) ) as Pin>>> }, #[cfg(feature = "sqlx-postgres")] @@ -45,7 +48,7 @@ impl<'a> TransactionStream<'a> { Box::pin( c.fetch(query) .map_ok(Into::into) - .map_err(sqlx_error_to_query_err) + .map_err(crate::sqlx_error_to_query_err) ) as Pin>>> }, #[cfg(feature = "sqlx-sqlite")] @@ -54,7 +57,7 @@ impl<'a> TransactionStream<'a> { Box::pin( c.fetch(query) .map_ok(Into::into) - .map_err(sqlx_error_to_query_err) + .map_err(crate::sqlx_error_to_query_err) ) as Pin>>> }, #[cfg(feature = "mock")] From b655de753887c4a952ddf4d8b3f83a0e0e7dc90e Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Wed, 29 Sep 2021 11:59:39 +0200 Subject: [PATCH 09/12] panic on drop of a locked transaction --- src/database/db_transaction.rs | 40 +++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs index fba8b857a..818ad5173 100644 --- a/src/database/db_transaction.rs +++ b/src/database/db_transaction.rs @@ -130,24 +130,28 @@ impl DatabaseTransaction { // the rollback is queued and will be performed on next async operation, like returning the connection to the pool fn start_rollback(&mut self) { if self.open { - match Arc::get_mut(&mut self.conn).map(|o| o.get_mut()) { - #[cfg(feature = "sqlx-mysql")] - Some(InnerConnection::MySql(c)) => { - ::TransactionManager::start_rollback(c); - }, - #[cfg(feature = "sqlx-postgres")] - Some(InnerConnection::Postgres(c)) => { - ::TransactionManager::start_rollback(c); - }, - #[cfg(feature = "sqlx-sqlite")] - Some(InnerConnection::Sqlite(c)) => { - ::TransactionManager::start_rollback(c); - }, - //Should we do something for mocked connections? - #[cfg(feature = "mock")] - Some(InnerConnection::Mock(_)) => {}, - //this happens if this is a nested transaction - None => unreachable!(), + if let Some(conn) = self.conn.try_lock() { + match *conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + ::TransactionManager::start_rollback(c); + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + ::TransactionManager::start_rollback(c); + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + ::TransactionManager::start_rollback(c); + }, + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + } + else { + //this should never happen + panic!("Dropping a locked Transaction"); } } } From 34a1dce5a862b35b3460e7d73a2dcb1b45ea8552 Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Wed, 29 Sep 2021 12:09:10 +0200 Subject: [PATCH 10/12] panic on drop of a locked transaction --- src/database/db_transaction.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs index 818ad5173..70ec488cd 100644 --- a/src/database/db_transaction.rs +++ b/src/database/db_transaction.rs @@ -131,7 +131,7 @@ impl DatabaseTransaction { fn start_rollback(&mut self) { if self.open { if let Some(conn) = self.conn.try_lock() { - match *conn { + match &mut *conn { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(c) => { ::TransactionManager::start_rollback(c); From 34a82c2dcdab97347adf2ca6c02e02a1e57481cd Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Wed, 29 Sep 2021 12:09:28 +0200 Subject: [PATCH 11/12] panic on drop of a locked transaction --- src/database/db_transaction.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs index 70ec488cd..b403971eb 100644 --- a/src/database/db_transaction.rs +++ b/src/database/db_transaction.rs @@ -130,7 +130,7 @@ impl DatabaseTransaction { // the rollback is queued and will be performed on next async operation, like returning the connection to the pool fn start_rollback(&mut self) { if self.open { - if let Some(conn) = self.conn.try_lock() { + if let Some(mut conn) = self.conn.try_lock() { match &mut *conn { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(c) => { From 4a2bc91cfdb213c50a3082e1e5598e01b8f70687 Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Wed, 29 Sep 2021 14:31:34 +0200 Subject: [PATCH 12/12] Centralize Sync requirement --- src/database/db_connection.rs | 2 +- src/entity/active_model.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 391a40d5d..569f38965 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -16,7 +16,7 @@ pub(crate) enum InnerConnection { } #[async_trait::async_trait] -pub trait ConnectionTrait<'a> { +pub trait ConnectionTrait<'a>: Sync { type Stream: Stream>; fn get_database_backend(&self) -> DbBackend; diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index 6979b9e32..32e9d77df 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -70,7 +70,7 @@ pub trait ActiveModelTrait: Clone + Debug { async fn insert<'a, C>(self, db: &'a C) -> Result where ::Model: IntoActiveModel, - C: ConnectionTrait<'a> + Sync, + C: ConnectionTrait<'a>, Self: 'a, { let am = self; @@ -94,7 +94,7 @@ pub trait ActiveModelTrait: Clone + Debug { async fn update<'a, C>(self, db: &'a C) -> Result where - C: ConnectionTrait<'a> + Sync, + C: ConnectionTrait<'a>, Self: 'a, { let exec = Self::Entity::update(self).exec(db); @@ -107,7 +107,7 @@ pub trait ActiveModelTrait: Clone + Debug { where Self: ActiveModelBehavior + 'a, ::Model: IntoActiveModel, - C: ConnectionTrait<'a> + Sync, + C: ConnectionTrait<'a>, { let mut am = self; am = ActiveModelBehavior::before_save(am); @@ -132,7 +132,7 @@ pub trait ActiveModelTrait: Clone + Debug { async fn delete<'a, C>(self, db: &'a C) -> Result where Self: ActiveModelBehavior + 'a, - C: ConnectionTrait<'a> + Sync, + C: ConnectionTrait<'a>, { let mut am = self; am = ActiveModelBehavior::before_delete(am);