Skip to content

Commit

Permalink
pool: fix panic when using callbacks
Browse files Browse the repository at this point in the history
add regression test
  • Loading branch information
abonander committed Jun 17, 2022
1 parent 339e058 commit 6a8149c
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sqlx-core/src/pool/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,10 @@ impl<DB: Database> Floating<DB, Idle<DB>> {
let now = Instant::now();

PoolConnectionMetadata {
age: self.created_at.duration_since(now),
idle_for: self.idle_since.duration_since(now),
// NOTE: the receiver is the later `Instant` and the arg is the earlier
// https://github.com/launchbadge/sqlx/issues/1912
age: now.saturating_duration_since(self.created_at),
idle_for: now.saturating_duration_since(self.idle_since),
}
}
}
Expand Down
125 changes: 125 additions & 0 deletions tests/any/pool.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use sqlx::any::AnyPoolOptions;
use sqlx::Executor;
use std::sync::atomic::AtomicI32;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
Expand Down Expand Up @@ -64,3 +66,126 @@ async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> {

Ok(())
}

#[sqlx_macros::test]
async fn test_pool_callbacks() -> anyhow::Result<()> {
sqlx_test::setup_if_needed();

#[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
struct ConnStats {
id: i32,
before_acquire_calls: i32,
after_release_calls: i32,
}

let current_id = AtomicI32::new(0);

let pool = AnyPoolOptions::new()
.max_connections(1)
.acquire_timeout(Duration::from_secs(5))
.after_connect(move |conn, meta| {
assert_eq!(meta.age, Duration::ZERO);
assert_eq!(meta.idle_for, Duration::ZERO);

let id = current_id.fetch_add(1, Ordering::AcqRel);

Box::pin(async move {
let statement = format!(
// language=SQL
r#"
CREATE TEMPORARY TABLE conn_stats(
id int primary key,
before_acquire_calls int default 0,
after_release_calls int default 0
);
INSERT INTO conn_stats(id) VALUES ({});
"#,
// Until we have generalized bind parameters
id
);

conn.execute(&statement[..]).await?;
Ok(())
})
})
.before_acquire(|conn, meta| {
// `age` and `idle_for` should both be nonzero
assert_ne!(meta.age, Duration::ZERO);
assert_ne!(meta.idle_for, Duration::ZERO);

Box::pin(async move {
let stats: ConnStats = sqlx::query_as(
r#"
UPDATE conn_stats
SET before_acquire_calls = before_acquire_calls + 1
RETURNING *
"#,
)
.fetch_one(conn)
.await?;

// For even IDs, cap by the number of before_acquire calls.
// Ignore the check for odd IDs.
Ok((stats.id & 1) == 1 || stats.before_acquire_calls < 3)
})
})
.after_release(|conn, meta| {
// `age` should be nonzero but `idle_for` should be zero.
assert_ne!(meta.age, Duration::ZERO);
assert_eq!(meta.idle_for, Duration::ZERO);

Box::pin(async move {
let stats: ConnStats = sqlx::query_as(
r#"
UPDATE conn_stats
SET after_release_calls = after_release_calls + 1
RETURNING *
"#,
)
.fetch_one(conn)
.await?;

// For odd IDs, cap by the number of before_release calls.
// Ignore the check for even IDs.
Ok((stats.id & 1) == 0 || stats.after_release_calls < 4)
})
})
// Don't establish a connection yet.
.connect_lazy(&dotenv::var("DATABASE_URL")?)?;

// Expected pattern of (id, before_acquire_calls, after_release_calls)
let pattern = [
// The connection pool starts empty.
(0, 0, 0),
(0, 1, 1),
(0, 2, 2),
(1, 0, 0),
(1, 1, 1),
(1, 2, 2),
// We should expect one more `acquire` because the ID is odd
(1, 3, 3),
(2, 0, 0),
(2, 1, 1),
(2, 2, 2),
(3, 0, 0),
];

for (id, before_acquire_calls, after_release_calls) in pattern {
let conn_stats: ConnStats = sqlx::query_as("SELECT * FROM conn_stats")
.fetch_one(&pool)
.await?;

assert_eq!(
conn_stats,
ConnStats {
id,
before_acquire_calls,
after_release_calls
}
);
}

pool.close().await;

Ok(())
}

0 comments on commit 6a8149c

Please sign in to comment.