
Commit

Merge pull request #4394 from weiznich/feature/share_statement_cache_with_diesel_async

Refactor the statement cache abstraction
weiznich authored Jan 11, 2025
2 parents 2129cd9 + ece87f1 commit 25cff1d
Showing 12 changed files with 313 additions and 108 deletions.
205 changes: 185 additions & 20 deletions diesel/src/connection/statement_cache/mod.rs
@@ -95,10 +95,13 @@
use std::any::TypeId;
use std::borrow::Cow;
use std::collections::hash_map::Entry;
use std::hash::Hash;
use std::ops::{Deref, DerefMut};

use strategy::{StatementCacheStrategy, WithCacheStrategy, WithoutCacheStrategy};
use strategy::{
LookupStatementResult, StatementCacheStrategy, WithCacheStrategy, WithoutCacheStrategy,
};

use crate::backend::Backend;
use crate::connection::InstrumentationEvent;
@@ -151,8 +154,8 @@ pub enum PrepareForCache {
impl<DB, Statement> StatementCache<DB, Statement>
where
DB: Backend + 'static,
Statement: 'static,
DB::TypeMetadata: Clone,
Statement: Send + 'static,
DB::TypeMetadata: Send + Clone,
DB::QueryBuilder: Default,
StatementCacheKey<DB>: Hash + Eq,
{
@@ -195,57 +198,152 @@ where
/// parameter indicates if the constructed prepared statement will be cached or not.
/// See the [module](self) documentation for details
/// about which statements are cached and which are not cached.
//
// Notes:
// This function explicitly takes a connection and a function pointer (and not a generic callback)
// as arguments to ensure that we don't leak generic query types into the prepare function
#[allow(unreachable_pub)]
pub fn cached_statement<T, F>(
&mut self,
pub fn cached_statement<'a, T, R, C>(
&'a mut self,
source: &T,
backend: &DB,
bind_types: &[DB::TypeMetadata],
mut prepare_fn: F,
conn: C,
prepare_fn: fn(C, &str, PrepareForCache, &[DB::TypeMetadata]) -> R,
instrumentation: &mut dyn Instrumentation,
) -> QueryResult<MaybeCached<'_, Statement>>
) -> R::Return<'a>
where
T: QueryFragment<DB> + QueryId,
F: FnMut(&str, PrepareForCache) -> QueryResult<Statement>,
R: StatementCallbackReturnType<Statement, C> + 'a,
{
Self::cached_statement_non_generic(
self.cache.as_mut(),
self.cached_statement_non_generic(
T::query_id(),
source,
backend,
bind_types,
&mut |sql, is_cached| {
conn,
prepare_fn,
instrumentation,
)
}

/// Prepare a query as a prepared statement
///
/// This function closely mirrors `Self::cached_statement` but
/// eliminates the generic query type in favour of a trait object
///
/// This can be easier to use in situations where you have already turned
/// the query type into a concrete SQL string
// Notes:
// This function explicitly takes a connection and a function pointer (and not a generic callback)
// as arguments to ensure that we don't leak generic query types into the prepare function
#[allow(unreachable_pub)]
#[allow(clippy::too_many_arguments)] // we need all of them
pub fn cached_statement_non_generic<'a, R, C>(
&'a mut self,
maybe_type_id: Option<TypeId>,
source: &dyn QueryFragmentForCachedStatement<DB>,
backend: &DB,
bind_types: &[DB::TypeMetadata],
conn: C,
prepare_fn: fn(C, &str, PrepareForCache, &[DB::TypeMetadata]) -> R,
instrumentation: &mut dyn Instrumentation,
) -> R::Return<'a>
where
R: StatementCallbackReturnType<Statement, C> + 'a,
{
Self::cached_statement_non_generic_impl(
self.cache.as_mut(),
maybe_type_id,
source,
backend,
bind_types,
conn,
|conn, sql, is_cached| {
if is_cached {
instrumentation.on_connection_event(InstrumentationEvent::CacheQuery { sql });
self.cache_counter += 1;
prepare_fn(
conn,
sql,
PrepareForCache::Yes {
counter: self.cache_counter,
},
bind_types,
)
} else {
prepare_fn(sql, PrepareForCache::No)
prepare_fn(conn, sql, PrepareForCache::No, bind_types)
}
},
)
}

/// Reduce the amount of monomorphized code by factoring this via dynamic dispatch
fn cached_statement_non_generic<'a>(
/// There will be only one instance of `R` for diesel (and a different single instance for diesel-async).
/// There will be only one instance per connection type `C` for each connection that
/// uses this prepared statement impl; this closely correlates to the types `DB` and `Statement`
/// of the overall statement cache impl
fn cached_statement_non_generic_impl<'a, R, C>(
cache: &'a mut dyn StatementCacheStrategy<DB, Statement>,
maybe_type_id: Option<TypeId>,
source: &dyn QueryFragmentForCachedStatement<DB>,
backend: &DB,
bind_types: &[DB::TypeMetadata],
prepare_fn: &mut dyn FnMut(&str, bool) -> QueryResult<Statement>,
) -> QueryResult<MaybeCached<'a, Statement>> {
let cache_key = StatementCacheKey::for_source(maybe_type_id, source, bind_types, backend)?;
if !source.is_safe_to_cache_prepared(backend)? {
let sql = cache_key.sql(source, backend)?;
return prepare_fn(&sql, false).map(MaybeCached::CannotCache);
conn: C,
prepare_fn: impl FnOnce(C, &str, bool) -> R,
) -> R::Return<'a>
where
R: StatementCallbackReturnType<Statement, C> + 'a,
{
// this function cannot use the `?` operator
// as we want to abstract over returning `QueryResult<MaybeCached>` and
// `impl Future<Output = QueryResult<MaybeCached>>` here
// to share the prepared statement cache implementation between diesel and
// diesel_async
//
// For this reason we need to match explicitly on each error and call `R::from_error()`
// to construct the right error return variant
let cache_key =
match StatementCacheKey::for_source(maybe_type_id, source, bind_types, backend) {
Ok(o) => o,
Err(e) => return R::from_error(e),
};
let is_safe_to_cache_prepared = match source.is_safe_to_cache_prepared(backend) {
Ok(o) => o,
Err(e) => return R::from_error(e),
};
// early return if the statement cannot be cached
if !is_safe_to_cache_prepared {
let sql = match cache_key.sql(source, backend) {
Ok(sql) => sql,
Err(e) => return R::from_error(e),
};
return prepare_fn(conn, &sql, false).map_to_no_cache();
}
let entry = cache.lookup_statement(cache_key);
match entry {
// The statement is already cached
LookupStatementResult::CacheEntry(Entry::Occupied(e)) => {
R::map_to_cache(e.into_mut(), conn)
}
// The statement is not cached but there is capacity to cache it
LookupStatementResult::CacheEntry(Entry::Vacant(e)) => {
let sql = match e.key().sql(source, backend) {
Ok(sql) => sql,
Err(e) => return R::from_error(e),
};
let st = prepare_fn(conn, &sql, true);
st.register_cache(|stmt| e.insert(stmt))
}
// The statement is not cached and there is no capacity to cache it
LookupStatementResult::NoCache(cache_key) => {
let sql = match cache_key.sql(source, backend) {
Ok(sql) => sql,
Err(e) => return R::from_error(e),
};
prepare_fn(conn, &sql, false).map_to_no_cache()
}
}
cache.get(cache_key, backend, source, prepare_fn)
}
}

@@ -266,9 +364,11 @@ where
pub trait QueryFragmentForCachedStatement<DB> {
/// Convert the query fragment into a SQL string for the given backend
fn construct_sql(&self, backend: &DB) -> QueryResult<String>;

/// Check whether it's safe to cache the query
fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool>;
}

impl<T, DB> QueryFragmentForCachedStatement<DB> for T
where
DB: Backend,
@@ -303,6 +403,71 @@ pub enum MaybeCached<'a, T: 'a> {
Cached(&'a mut T),
}

/// This trait abstracts over the type returned by the prepare statement function
///
/// The main use-case for this abstraction is to share the same statement cache implementation
/// between diesel and diesel-async.
#[cfg_attr(
docsrs,
doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))
)]
#[allow(unreachable_pub)]
pub trait StatementCallbackReturnType<S: 'static, C> {
/// The return type of `StatementCache::cached_statement`
///
/// Either a `QueryResult<MaybeCached<S>>` or a future of that result type
type Return<'a>;

/// Create the return type from an error
fn from_error<'a>(e: diesel::result::Error) -> Self::Return<'a>;

/// Map the callback return type to the `MaybeCached::CannotCache` variant
fn map_to_no_cache<'a>(self) -> Self::Return<'a>
where
Self: 'a;

/// Map the cached statement to the `MaybeCached::Cached` variant
fn map_to_cache(stmt: &mut S, conn: C) -> Self::Return<'_>;

/// Insert the created statement into the cache via the provided callback
/// and then turn the returned reference into `MaybeCached::Cached`
fn register_cache<'a>(
self,
callback: impl FnOnce(S) -> &'a mut S + Send + 'a,
) -> Self::Return<'a>
where
Self: 'a;
}

impl<S, C> StatementCallbackReturnType<S, C> for QueryResult<S>
where
S: 'static,
{
type Return<'a> = QueryResult<MaybeCached<'a, S>>;

fn from_error<'a>(e: diesel::result::Error) -> Self::Return<'a> {
Err(e)
}

fn map_to_no_cache<'a>(self) -> Self::Return<'a> {
self.map(MaybeCached::CannotCache)
}

fn map_to_cache(stmt: &mut S, _conn: C) -> Self::Return<'_> {
Ok(MaybeCached::Cached(stmt))
}

fn register_cache<'a>(
self,
callback: impl FnOnce(S) -> &'a mut S + Send + 'a,
) -> Self::Return<'a>
where
Self: 'a,
{
Ok(MaybeCached::Cached(callback(self?)))
}
}

impl<T> Deref for MaybeCached<'_, T> {
type Target = T;

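To make the intent of the new `StatementCallbackReturnType` abstraction concrete: on the diesel side `R` is simply `QueryResult<Statement>`, so `Return<'a>` resolves to `QueryResult<MaybeCached<'a, Statement>>` and `cached_statement` behaves exactly as before (see the impl above). The sketch below shows roughly how an async driver such as diesel-async could implement the trait with a boxed future. It is not part of this commit; the `PrepareFuture` wrapper, the `diesel::connection::statement_cache` import path, and the gating behind the `i-implement-a-third-party-backend-and-opt-into-breaking-changes` feature are assumptions made for illustration.

// Illustrative sketch only, not from this commit: an async-flavoured
// `StatementCallbackReturnType` impl. All names outside the trait itself are hypothetical.
use std::future::Future;
use std::pin::Pin;

use diesel::connection::statement_cache::{MaybeCached, StatementCallbackReturnType};
use diesel::QueryResult;

// Hypothetical wrapper around whatever future the async prepare call returns.
struct PrepareFuture<S>(Pin<Box<dyn Future<Output = QueryResult<S>> + Send>>);

impl<S, C> StatementCallbackReturnType<S, C> for PrepareFuture<S>
where
    S: Send + 'static,
{
    // The whole cache lookup becomes a future instead of a plain `QueryResult`.
    type Return<'a> = Pin<Box<dyn Future<Output = QueryResult<MaybeCached<'a, S>>> + Send + 'a>>;

    fn from_error<'a>(e: diesel::result::Error) -> Self::Return<'a> {
        let res: QueryResult<MaybeCached<'a, S>> = Err(e);
        Box::pin(async move { res })
    }

    fn map_to_no_cache<'a>(self) -> Self::Return<'a>
    where
        Self: 'a,
    {
        // Await the prepared statement and hand it back without caching it.
        Box::pin(async move { self.0.await.map(MaybeCached::CannotCache) })
    }

    fn map_to_cache(stmt: &mut S, _conn: C) -> Self::Return<'_> {
        // The statement already lives in the cache; just return the borrow.
        let res: QueryResult<MaybeCached<'_, S>> = Ok(MaybeCached::Cached(stmt));
        Box::pin(async move { res })
    }

    fn register_cache<'a>(
        self,
        callback: impl FnOnce(S) -> &'a mut S + Send + 'a,
    ) -> Self::Return<'a>
    where
        Self: 'a,
    {
        // Await the freshly prepared statement, store it via the callback,
        // and return a reference to the cached entry.
        Box::pin(async move {
            let stmt = match self.0.await {
                Ok(stmt) => stmt,
                Err(e) => return Err(e),
            };
            QueryResult::Ok(MaybeCached::Cached(callback(stmt)))
        })
    }
}

Boxing the future keeps the sketch short; the statement cache only relies on the GAT `Return<'a>`, so an implementation is free to use a more specific future type.
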
79 changes: 40 additions & 39 deletions diesel/src/connection/statement_cache/strategy.rs
@@ -1,33 +1,47 @@
use crate::backend::Backend;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::hash::Hash;

use crate::{backend::Backend, result::Error};

use super::{CacheSize, MaybeCached, QueryFragmentForCachedStatement, StatementCacheKey};
use super::{CacheSize, StatementCacheKey};

/// Indicates the cache key status
//
// This is a separate enum and not just `Option<Entry>`
// as we need to return the cache key for ownership reasons
// if we don't have a cache at all
#[cfg_attr(
feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes",
allow(missing_debug_implementations)
)]
// cannot implement debug easily as StatementCacheKey is not Debug
pub enum LookupStatementResult<'a, DB, Statement>
where
DB: Backend,
{
/// The cache entry, either already populated or vacant;
/// in the latter case the caller needs to prepare the
/// statement and insert it into the cache
CacheEntry(Entry<'a, StatementCacheKey<DB>, Statement>),
/// This key should not be cached
NoCache(StatementCacheKey<DB>),
}

/// Implement this trait, in order to control statement caching.
#[allow(unreachable_pub)]
pub trait StatementCacheStrategy<DB, Statement>
pub trait StatementCacheStrategy<DB, Statement>: Send + 'static
where
DB: Backend,
StatementCacheKey<DB>: Hash + Eq,
{
/// Returns which prepared statement cache size is implemented by this trait
fn cache_size(&self) -> CacheSize;

/// Every query (which is safe to cache) will go through this function
/// The implementation will decide whether to cache statement or not
/// * `prepare_fn` - will be invoked if prepared statement wasn't cached already
/// * first argument is sql query string
/// * second argument specifies whether statement will be cached (true) or not (false).
fn get(
/// Returns whether or not a prepared statement for the corresponding cache key is already cached
fn lookup_statement(
&mut self,
key: StatementCacheKey<DB>,
backend: &DB,
source: &dyn QueryFragmentForCachedStatement<DB>,
prepare_fn: &mut dyn FnMut(&str, bool) -> Result<Statement, Error>,
) -> Result<MaybeCached<'_, Statement>, Error>;
) -> LookupStatementResult<'_, DB, Statement>;
}

/// Cache all (safe) statements for as long as connection is alive.
@@ -52,27 +66,17 @@ where

impl<DB, Statement> StatementCacheStrategy<DB, Statement> for WithCacheStrategy<DB, Statement>
where
DB: Backend,
DB: Backend + 'static,
StatementCacheKey<DB>: Hash + Eq,
DB::TypeMetadata: Clone,
DB::TypeMetadata: Send + Clone,
DB::QueryBuilder: Default,
Statement: Send + 'static,
{
fn get(
fn lookup_statement(
&mut self,
key: StatementCacheKey<DB>,
backend: &DB,
source: &dyn QueryFragmentForCachedStatement<DB>,
prepare_fn: &mut dyn FnMut(&str, bool) -> Result<Statement, Error>,
) -> Result<MaybeCached<'_, Statement>, Error> {
let entry = self.cache.entry(key);
match entry {
Entry::Occupied(e) => Ok(MaybeCached::Cached(e.into_mut())),
Entry::Vacant(e) => {
let sql = e.key().sql(source, backend)?;
let st = prepare_fn(&sql, true)?;
Ok(MaybeCached::Cached(e.insert(st)))
}
}
entry: StatementCacheKey<DB>,
) -> LookupStatementResult<'_, DB, Statement> {
LookupStatementResult::CacheEntry(self.cache.entry(entry))
}

fn cache_size(&self) -> CacheSize {
@@ -91,16 +95,13 @@
StatementCacheKey<DB>: Hash + Eq,
DB::TypeMetadata: Clone,
DB::QueryBuilder: Default,
Statement: 'static,
{
fn get(
fn lookup_statement(
&mut self,
key: StatementCacheKey<DB>,
backend: &DB,
source: &dyn QueryFragmentForCachedStatement<DB>,
prepare_fn: &mut dyn FnMut(&str, bool) -> Result<Statement, Error>,
) -> Result<MaybeCached<'_, Statement>, Error> {
let sql = key.sql(source, backend)?;
Ok(MaybeCached::CannotCache(prepare_fn(&sql, false)?))
entry: StatementCacheKey<DB>,
) -> LookupStatementResult<'_, DB, Statement> {
LookupStatementResult::NoCache(entry)
}

fn cache_size(&self) -> CacheSize {
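The `lookup_statement`/`LookupStatementResult` split also makes it possible to sketch strategies beyond the two shipped here. Below is a hypothetical capacity-bounded strategy that keeps caching statements only while it has spare room and otherwise hands the key back via `LookupStatementResult::NoCache`. It is not part of this commit; the `BoundedCacheStrategy` name, the `max_entries` field, the import path, and reporting `CacheSize::Unbounded` from `cache_size` are assumptions made for illustration.

// Illustrative sketch only, not from this commit: a capacity-bounded strategy
// built on the new `lookup_statement` API. Names outside the trait are hypothetical.
use std::collections::HashMap;
use std::hash::Hash;

use diesel::backend::Backend;
use diesel::connection::statement_cache::{
    CacheSize, LookupStatementResult, StatementCacheKey, StatementCacheStrategy,
};

struct BoundedCacheStrategy<DB: Backend, Statement> {
    cache: HashMap<StatementCacheKey<DB>, Statement>,
    max_entries: usize,
}

impl<DB, Statement> StatementCacheStrategy<DB, Statement> for BoundedCacheStrategy<DB, Statement>
where
    DB: Backend + 'static,
    StatementCacheKey<DB>: Hash + Eq,
    DB::TypeMetadata: Send + Clone,
    DB::QueryBuilder: Default,
    Statement: Send + 'static,
{
    fn cache_size(&self) -> CacheSize {
        // `CacheSize` has no bounded variant in this diff, so report the closest
        // match; purely illustrative.
        CacheSize::Unbounded
    }

    fn lookup_statement(
        &mut self,
        key: StatementCacheKey<DB>,
    ) -> LookupStatementResult<'_, DB, Statement> {
        // Reuse an existing entry, or claim a vacant one while capacity remains;
        // otherwise hand the key back so the statement is prepared without caching.
        if self.cache.contains_key(&key) || self.cache.len() < self.max_entries {
            LookupStatementResult::CacheEntry(self.cache.entry(key))
        } else {
            LookupStatementResult::NoCache(key)
        }
    }
}

Note that a strategy no longer prepares anything itself: preparing and inserting now happen in `cached_statement_non_generic_impl`, which keeps the caching decision separate from the (possibly async) prepare step.
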
