Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pool: fix panic when using callbacks #1915

Merged
merged 2 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sqlx-core/src/mssql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ pub use value::{MssqlValue, MssqlValueRef};
/// An alias for [`Pool`][crate::pool::Pool], specialized for MSSQL.
pub type MssqlPool = crate::pool::Pool<Mssql>;

/// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for MSSQL.
pub type MssqlPoolOptions = crate::pool::PoolOptions<Mssql>;

/// An alias for [`Executor<'_, Database = Mssql>`][Executor].
pub trait MssqlExecutor<'c>: Executor<'c, Database = Mssql> {}
impl<'c, T: Executor<'c, Database = Mssql>> MssqlExecutor<'c> for T {}
Expand Down
6 changes: 4 additions & 2 deletions sqlx-core/src/pool/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,10 @@ impl<DB: Database> Floating<DB, Idle<DB>> {
let now = Instant::now();

PoolConnectionMetadata {
age: self.created_at.duration_since(now),
idle_for: self.idle_since.duration_since(now),
// NOTE: the receiver is the later `Instant` and the arg is the earlier
// https://github.com/launchbadge/sqlx/issues/1912
age: now.saturating_duration_since(self.created_at),
idle_for: now.saturating_duration_since(self.idle_since),
}
}
}
Expand Down
148 changes: 147 additions & 1 deletion tests/any/pool.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use sqlx::any::AnyPoolOptions;
use sqlx::any::{AnyConnectOptions, AnyKind, AnyPoolOptions};
use sqlx::Executor;
use std::sync::atomic::AtomicI32;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
Expand Down Expand Up @@ -64,3 +66,147 @@ async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> {

Ok(())
}

#[sqlx_macros::test]
async fn test_pool_callbacks() -> anyhow::Result<()> {
#[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
struct ConnStats {
id: i32,
before_acquire_calls: i32,
after_release_calls: i32,
}

sqlx_test::setup_if_needed();

let conn_options: AnyConnectOptions = std::env::var("DATABASE_URL")?.parse()?;

#[cfg(feature = "mssql")]
if conn_options.kind() == AnyKind::Mssql {
// MSSQL doesn't support `CREATE TEMPORARY TABLE`,
// because why follow conventions when you can subvert them?
// Instead, you prepend `#` to the table name for a session-local temporary table
// which you also have to do when referencing it.

// Since that affects basically every query here,
// it's just easier to have a separate MSSQL-specific test case.
return Ok(());
}

let current_id = AtomicI32::new(0);

let pool = AnyPoolOptions::new()
.max_connections(1)
.acquire_timeout(Duration::from_secs(5))
.after_connect(move |conn, meta| {
assert_eq!(meta.age, Duration::ZERO);
assert_eq!(meta.idle_for, Duration::ZERO);

let id = current_id.fetch_add(1, Ordering::AcqRel);

Box::pin(async move {
let statement = format!(
// language=SQL
r#"
CREATE TEMPORARY TABLE conn_stats(
id int primary key,
before_acquire_calls int default 0,
after_release_calls int default 0
);
INSERT INTO conn_stats(id) VALUES ({});
"#,
// Until we have generalized bind parameters
id
);

conn.execute(&statement[..]).await?;
Ok(())
})
})
.before_acquire(|conn, meta| {
// `age` and `idle_for` should both be nonzero
assert_ne!(meta.age, Duration::ZERO);
assert_ne!(meta.idle_for, Duration::ZERO);

Box::pin(async move {
// MySQL and MariaDB don't support UPDATE ... RETURNING
sqlx::query(
r#"
UPDATE conn_stats
SET before_acquire_calls = before_acquire_calls + 1
"#,
)
.execute(&mut *conn)
.await?;

let stats: ConnStats = sqlx::query_as("SELECT * FROM conn_stats")
.fetch_one(conn)
.await?;

// For even IDs, cap by the number of before_acquire calls.
// Ignore the check for odd IDs.
Ok((stats.id & 1) == 1 || stats.before_acquire_calls < 3)
})
})
.after_release(|conn, meta| {
// `age` should be nonzero but `idle_for` should be zero.
assert_ne!(meta.age, Duration::ZERO);
assert_eq!(meta.idle_for, Duration::ZERO);

Box::pin(async move {
sqlx::query(
r#"
UPDATE conn_stats
SET after_release_calls = after_release_calls + 1
"#,
)
.execute(&mut *conn)
.await?;

let stats: ConnStats = sqlx::query_as("SELECT * FROM conn_stats")
.fetch_one(conn)
.await?;

// For odd IDs, cap by the number of before_release calls.
// Ignore the check for even IDs.
Ok((stats.id & 1) == 0 || stats.after_release_calls < 4)
})
})
// Don't establish a connection yet.
.connect_lazy_with(conn_options);

// Expected pattern of (id, before_acquire_calls, after_release_calls)
let pattern = [
// The connection pool starts empty.
(0, 0, 0),
(0, 1, 1),
(0, 2, 2),
(1, 0, 0),
(1, 1, 1),
(1, 2, 2),
// We should expect one more `acquire` because the ID is odd
(1, 3, 3),
(2, 0, 0),
(2, 1, 1),
(2, 2, 2),
(3, 0, 0),
];

for (id, before_acquire_calls, after_release_calls) in pattern {
let conn_stats: ConnStats = sqlx::query_as("SELECT * FROM conn_stats")
.fetch_one(&pool)
.await?;

assert_eq!(
conn_stats,
ConnStats {
id,
before_acquire_calls,
after_release_calls
}
);
}

pool.close().await;

Ok(())
}
136 changes: 135 additions & 1 deletion tests/mssql/mssql.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use futures::TryStreamExt;
use sqlx::mssql::Mssql;
use sqlx::mssql::{Mssql, MssqlPoolOptions};
use sqlx::{Column, Connection, Executor, MssqlConnection, Row, Statement, TypeInfo};
use sqlx_core::mssql::MssqlRow;
use sqlx_test::new;
use std::sync::atomic::{AtomicI32, Ordering};
use std::time::Duration;

#[sqlx_macros::test]
async fn it_connects() -> anyhow::Result<()> {
Expand Down Expand Up @@ -325,3 +327,135 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> {

Ok(())
}

// MSSQL-specific copy of the test case in `tests/any/pool.rs`
// because MSSQL has its own bespoke syntax for temporary tables.
#[sqlx_macros::test]
async fn test_pool_callbacks() -> anyhow::Result<()> {
#[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
struct ConnStats {
id: i32,
before_acquire_calls: i32,
after_release_calls: i32,
}

sqlx_test::setup_if_needed();

let current_id = AtomicI32::new(0);

let pool = MssqlPoolOptions::new()
.max_connections(1)
.acquire_timeout(Duration::from_secs(5))
.after_connect(move |conn, meta| {
assert_eq!(meta.age, Duration::ZERO);
assert_eq!(meta.idle_for, Duration::ZERO);

let id = current_id.fetch_add(1, Ordering::AcqRel);

Box::pin(async move {
let statement = format!(
// language=MSSQL
r#"
CREATE TABLE #conn_stats(
id int primary key,
before_acquire_calls int default 0,
after_release_calls int default 0
);
INSERT INTO #conn_stats(id) VALUES ({});
"#,
// Until we have generalized bind parameters
id
);

conn.execute(&statement[..]).await?;
Ok(())
})
})
.before_acquire(|conn, meta| {
// `age` and `idle_for` should both be nonzero
assert_ne!(meta.age, Duration::ZERO);
assert_ne!(meta.idle_for, Duration::ZERO);

Box::pin(async move {
// MSSQL doesn't support UPDATE ... RETURNING either
sqlx::query(
r#"
UPDATE #conn_stats
SET before_acquire_calls = before_acquire_calls + 1
"#,
)
.execute(&mut *conn)
.await?;

let stats: ConnStats = sqlx::query_as("SELECT * FROM #conn_stats")
.fetch_one(conn)
.await?;

// For even IDs, cap by the number of before_acquire calls.
// Ignore the check for odd IDs.
Ok((stats.id & 1) == 1 || stats.before_acquire_calls < 3)
})
})
.after_release(|conn, meta| {
// `age` should be nonzero but `idle_for` should be zero.
assert_ne!(meta.age, Duration::ZERO);
assert_eq!(meta.idle_for, Duration::ZERO);

Box::pin(async move {
sqlx::query(
r#"
UPDATE #conn_stats
SET after_release_calls = after_release_calls + 1
"#,
)
.execute(&mut *conn)
.await?;

let stats: ConnStats = sqlx::query_as("SELECT * FROM #conn_stats")
.fetch_one(conn)
.await?;

// For odd IDs, cap by the number of before_release calls.
// Ignore the check for even IDs.
Ok((stats.id & 1) == 0 || stats.after_release_calls < 4)
})
})
// Don't establish a connection yet.
.connect_lazy(&std::env::var("DATABASE_URL")?)?;

// Expected pattern of (id, before_acquire_calls, after_release_calls)
let pattern = [
// The connection pool starts empty.
(0, 0, 0),
(0, 1, 1),
(0, 2, 2),
(1, 0, 0),
(1, 1, 1),
(1, 2, 2),
// We should expect one more `acquire` because the ID is odd
(1, 3, 3),
(2, 0, 0),
(2, 1, 1),
(2, 2, 2),
(3, 0, 0),
];

for (id, before_acquire_calls, after_release_calls) in pattern {
let conn_stats: ConnStats = sqlx::query_as("SELECT * FROM #conn_stats")
.fetch_one(&pool)
.await?;

assert_eq!(
conn_stats,
ConnStats {
id,
before_acquire_calls,
after_release_calls
}
);
}

pool.close().await;

Ok(())
}