Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement before_connect callback to modify connect options. #3562

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions sqlx-core/src/pool/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crossbeam_queue::ArrayQueue;

use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser};

use std::borrow::Cow;
use std::cmp;
use std::future::Future;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
Expand Down Expand Up @@ -324,6 +325,36 @@ impl<DB: Database> PoolInner<DB> {
Ok(acquired)
}

/// Attempts to get connect options, possibly modify using before_connect, then connect.
///
/// Wrapping this code in a timeout allows the total time taken for these steps to
/// be bounded by the connection deadline.
async fn get_connect_options_and_connect(
self: &Arc<Self>,
num_attempts: u32,
) -> Result<DB::Connection, Error> {
// clone the connect options arc so it can be used without holding the RwLockReadGuard
// across an async await point
let connect_options_arc = self
.connect_options
.read()
.expect("write-lock holder panicked")
.clone();

let connect_options = if let Some(callback) = &self.options.before_connect {
callback(connect_options_arc.as_ref(), num_attempts)
.await
.map_err(|error| {
tracing::error!(%error, "error returned from before_connect");
error
})?
} else {
Cow::Borrowed(connect_options_arc.as_ref())
};

connect_options.connect().await
}

pub(super) async fn connect(
self: &Arc<Self>,
deadline: Instant,
Expand All @@ -335,21 +366,17 @@ impl<DB: Database> PoolInner<DB> {

let mut backoff = Duration::from_millis(10);
let max_backoff = deadline_as_timeout(deadline)? / 5;
let mut num_attempts: u32 = 0;

loop {
let timeout = deadline_as_timeout(deadline)?;

// clone the connect options arc so it can be used without holding the RwLockReadGuard
// across an async await point
let connect_options = self
.connect_options
.read()
.expect("write-lock holder panicked")
.clone();
num_attempts += 1;

// result here is `Result<Result<C, Error>, TimeoutError>`
// if this block does not return, sleep for the backoff timeout and try again
match crate::rt::timeout(timeout, connect_options.connect()).await {
match crate::rt::timeout(timeout, self.get_connect_options_and_connect(num_attempts))
.await
{
// successfully established connection
Ok(Ok(mut raw)) => {
// See comment on `PoolOptions::after_connect`
Expand Down
63 changes: 63 additions & 0 deletions sqlx-core/src/pool/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::pool::inner::PoolInner;
use crate::pool::Pool;
use futures_core::future::BoxFuture;
use log::LevelFilter;
use std::borrow::Cow;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -44,6 +45,18 @@ use std::time::{Duration, Instant};
/// the perspectives of both API designer and consumer.
pub struct PoolOptions<DB: Database> {
pub(crate) test_before_acquire: bool,
pub(crate) before_connect: Option<
Arc<
dyn Fn(
&<DB::Connection as Connection>::Options,
u32,
)
-> BoxFuture<'_, Result<Cow<'_, <DB::Connection as Connection>::Options>, Error>>
+ 'static
+ Send
+ Sync,
>,
>,
pub(crate) after_connect: Option<
Arc<
dyn Fn(&mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'_, Result<(), Error>>
Expand Down Expand Up @@ -94,6 +107,7 @@ impl<DB: Database> Clone for PoolOptions<DB> {
fn clone(&self) -> Self {
PoolOptions {
test_before_acquire: self.test_before_acquire,
before_connect: self.before_connect.clone(),
after_connect: self.after_connect.clone(),
before_acquire: self.before_acquire.clone(),
after_release: self.after_release.clone(),
Expand Down Expand Up @@ -143,6 +157,7 @@ impl<DB: Database> PoolOptions<DB> {
pub fn new() -> Self {
Self {
// User-specifiable routines
before_connect: None,
after_connect: None,
before_acquire: None,
after_release: None,
Expand Down Expand Up @@ -339,6 +354,54 @@ impl<DB: Database> PoolOptions<DB> {
self
}

/// Perform an asynchronous action before connecting to the database.
///
/// This operation is performed on every attempt to connect, including retries. The
/// current `ConnectOptions` is passed, and this may be passed unchanged, or modified
/// after cloning. The current connection attempt is passed as the second parameter
/// (starting at 1).
///
/// If the operation returns with an error, then the connection attempt fails without
/// attempting further retries. The operation therefore may need to implement error
/// handling and/or value caching to avoid failing the connection attempt.
///
/// # Example: Per-Request Authentication
/// This callback may be used to modify values in the database's `ConnectOptions`, before
/// connecting to the database.
///
/// This example is written for PostgreSQL but can likely be adapted to other databases.
///
/// ```no_run
/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
/// use std::borrow::Cow;
/// use sqlx::Executor;
/// use sqlx::postgres::PgPoolOptions;
///
/// let pool = PgPoolOptions::new()
/// .before_connect(move |opts, _num_attempts| Box::pin(async move {
/// Ok(Cow::Owned(opts.clone().password("abc")))
/// }))
/// .connect("postgres:// …").await?;
/// # Ok(())
/// # }
/// ```
///
/// For a discussion on why `Box::pin()` is required, see [the type-level docs][Self].
pub fn before_connect<F>(mut self, callback: F) -> Self
where
for<'c> F: Fn(
&'c <DB::Connection as Connection>::Options,
u32,
)
-> BoxFuture<'c, crate::Result<Cow<'c, <DB::Connection as Connection>::Options>>>
+ 'static
+ Send
+ Sync,
{
self.before_connect = Some(Arc::new(callback));
self
}

/// Perform an asynchronous action after connecting to the database.
///
/// If the operation returns with an error then the error is logged, the connection is closed
Expand Down
Loading