diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index bbcc43134e..b79dea3f90 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -8,6 +8,7 @@ use crossbeam_queue::ArrayQueue; use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser}; +use std::borrow::Cow; use std::cmp; use std::future::Future; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}; @@ -324,6 +325,36 @@ impl PoolInner { Ok(acquired) } + /// Attempts to get connect options, possibly modify using before_connect, then connect. + /// + /// Wrapping this code in a timeout allows the total time taken for these steps to + /// be bounded by the connection deadline. + async fn get_connect_options_and_connect( + self: &Arc, + num_attempts: u32, + ) -> Result { + // clone the connect options arc so it can be used without holding the RwLockReadGuard + // across an async await point + let connect_options_arc = self + .connect_options + .read() + .expect("write-lock holder panicked") + .clone(); + + let connect_options = if let Some(callback) = &self.options.before_connect { + callback(connect_options_arc.as_ref(), num_attempts) + .await + .map_err(|error| { + tracing::error!(%error, "error returned from before_connect"); + error + })? + } else { + Cow::Borrowed(connect_options_arc.as_ref()) + }; + + connect_options.connect().await + } + pub(super) async fn connect( self: &Arc, deadline: Instant, @@ -335,21 +366,17 @@ impl PoolInner { let mut backoff = Duration::from_millis(10); let max_backoff = deadline_as_timeout(deadline)? / 5; + let mut num_attempts: u32 = 0; loop { let timeout = deadline_as_timeout(deadline)?; - - // clone the connect options arc so it can be used without holding the RwLockReadGuard - // across an async await point - let connect_options = self - .connect_options - .read() - .expect("write-lock holder panicked") - .clone(); + num_attempts += 1; // result here is `Result, TimeoutError>` // if this block does not return, sleep for the backoff timeout and try again - match crate::rt::timeout(timeout, connect_options.connect()).await { + match crate::rt::timeout(timeout, self.get_connect_options_and_connect(num_attempts)) + .await + { // successfully established connection Ok(Ok(mut raw)) => { // See comment on `PoolOptions::after_connect` diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 96dbf8ee3d..17b2c8ea04 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -5,6 +5,7 @@ use crate::pool::inner::PoolInner; use crate::pool::Pool; use futures_core::future::BoxFuture; use log::LevelFilter; +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -44,6 +45,18 @@ use std::time::{Duration, Instant}; /// the perspectives of both API designer and consumer. pub struct PoolOptions { pub(crate) test_before_acquire: bool, + pub(crate) before_connect: Option< + Arc< + dyn Fn( + &::Options, + u32, + ) + -> BoxFuture<'_, Result::Options>, Error>> + + 'static + + Send + + Sync, + >, + >, pub(crate) after_connect: Option< Arc< dyn Fn(&mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'_, Result<(), Error>> @@ -94,6 +107,7 @@ impl Clone for PoolOptions { fn clone(&self) -> Self { PoolOptions { test_before_acquire: self.test_before_acquire, + before_connect: self.before_connect.clone(), after_connect: self.after_connect.clone(), before_acquire: self.before_acquire.clone(), after_release: self.after_release.clone(), @@ -143,6 +157,7 @@ impl PoolOptions { pub fn new() -> Self { Self { // User-specifiable routines + before_connect: None, after_connect: None, before_acquire: None, after_release: None, @@ -339,6 +354,54 @@ impl PoolOptions { self } + /// Perform an asynchronous action before connecting to the database. + /// + /// This operation is performed on every attempt to connect, including retries. The + /// current `ConnectOptions` is passed, and this may be passed unchanged, or modified + /// after cloning. The current connection attempt is passed as the second parameter + /// (starting at 1). + /// + /// If the operation returns with an error, then the connection attempt fails without + /// attempting further retries. The operation therefore may need to implement error + /// handling and/or value caching to avoid failing the connection attempt. + /// + /// # Example: Per-Request Authentication + /// This callback may be used to modify values in the database's `ConnectOptions`, before + /// connecting to the database. + /// + /// This example is written for PostgreSQL but can likely be adapted to other databases. + /// + /// ```no_run + /// # async fn f() -> Result<(), Box> { + /// use std::borrow::Cow; + /// use sqlx::Executor; + /// use sqlx::postgres::PgPoolOptions; + /// + /// let pool = PgPoolOptions::new() + /// .before_connect(move |opts, _num_attempts| Box::pin(async move { + /// Ok(Cow::Owned(opts.clone().password("abc"))) + /// })) + /// .connect("postgres:// …").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// For a discussion on why `Box::pin()` is required, see [the type-level docs][Self]. + pub fn before_connect(mut self, callback: F) -> Self + where + for<'c> F: Fn( + &'c ::Options, + u32, + ) + -> BoxFuture<'c, crate::Result::Options>>> + + 'static + + Send + + Sync, + { + self.before_connect = Some(Arc::new(callback)); + self + } + /// Perform an asynchronous action after connecting to the database. /// /// If the operation returns with an error then the error is logged, the connection is closed