diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs index 98608a170..7b942c85c 100644 --- a/scylla/src/lib.rs +++ b/scylla/src/lib.rs @@ -190,7 +190,7 @@ pub use statement::query; pub use frame::response::cql_to_rust; pub use frame::response::cql_to_rust::FromRow; -pub use transport::caching_session::CachingSession; +pub use transport::caching_session::{CachingSession, LegacyCachingSession}; pub use transport::execution_profile::ExecutionProfile; pub use transport::legacy_query_result::LegacyQueryResult; pub use transport::query_result::QueryResult; diff --git a/scylla/src/transport/caching_session.rs b/scylla/src/transport/caching_session.rs index 5e77c48df..e61a2dfe2 100644 --- a/scylla/src/transport/caching_session.rs +++ b/scylla/src/transport/caching_session.rs @@ -5,7 +5,7 @@ use crate::statement::{PagingState, PagingStateResponse}; use crate::transport::errors::QueryError; use crate::transport::iterator::LegacyRowIterator; use crate::transport::partitioner::PartitionerName; -use crate::{LegacyQueryResult, LegacySession}; +use crate::{LegacyQueryResult, QueryResult}; use bytes::Bytes; use dashmap::DashMap; use futures::future::try_join_all; @@ -16,6 +16,11 @@ use std::collections::hash_map::RandomState; use std::hash::BuildHasher; use std::sync::Arc; +use super::iterator::RawIterator; +use super::session::{ + CurrentDeserializationApi, DeserializationApiKind, GenericSession, LegacyDeserializationApi, +}; + /// Contains just the parts of a prepared statement that were returned /// from the database. All remaining parts (query string, page size, /// consistency, etc.) are taken from the Query passed @@ -31,11 +36,12 @@ struct RawPreparedStatementData { /// Provides auto caching while executing queries #[derive(Debug)] -pub struct CachingSession +pub struct GenericCachingSession where S: Clone + BuildHasher, + DeserializationApi: DeserializationApiKind, { - session: LegacySession, + session: GenericSession, /// The prepared statement cache size /// If a prepared statement is added while the limit is reached, the oldest prepared statement /// is removed from the cache @@ -43,11 +49,15 @@ where cache: DashMap, } -impl CachingSession +pub type CachingSession = GenericCachingSession; +pub type LegacyCachingSession = GenericCachingSession; + +impl GenericCachingSession where S: Default + BuildHasher + Clone, + DeserApi: DeserializationApiKind, { - pub fn from(session: LegacySession, cache_size: usize) -> Self { + pub fn from(session: GenericSession, cache_size: usize) -> Self { Self { session, max_capacity: cache_size, @@ -56,20 +66,88 @@ where } } -impl CachingSession +impl GenericCachingSession where S: BuildHasher + Clone, + DeserApi: DeserializationApiKind, { /// Builds a [`CachingSession`] from a [`Session`], a cache size, and a [`BuildHasher`]., /// using a customer hasher. - pub fn with_hasher(session: LegacySession, cache_size: usize, hasher: S) -> Self { + pub fn with_hasher(session: GenericSession, cache_size: usize, hasher: S) -> Self { Self { session, max_capacity: cache_size, cache: DashMap::with_hasher(hasher), } } +} +impl GenericCachingSession +where + S: BuildHasher + Clone, +{ + /// Does the same thing as [`Session::execute_unpaged`] but uses the prepared statement cache + pub async fn execute_unpaged( + &self, + query: impl Into, + values: impl SerializeRow, + ) -> Result { + let query = query.into(); + let prepared = self.add_prepared_statement_owned(query).await?; + self.session.execute_unpaged(&prepared, values).await + } + + /// Does the same thing as [`Session::execute_iter`] but uses the prepared statement cache + pub async fn execute_iter( + &self, + query: impl Into, + values: impl SerializeRow, + ) -> Result { + let query = query.into(); + let prepared = self.add_prepared_statement_owned(query).await?; + self.session.execute_iter(prepared, values).await + } + + /// Does the same thing as [`Session::execute_single_page`] but uses the prepared statement cache + pub async fn execute_single_page( + &self, + query: impl Into, + values: impl SerializeRow, + paging_state: PagingState, + ) -> Result<(QueryResult, PagingStateResponse), QueryError> { + let query = query.into(); + let prepared = self.add_prepared_statement_owned(query).await?; + self.session + .execute_single_page(&prepared, values, paging_state) + .await + } + + /// Does the same thing as [`Session::batch`] but uses the prepared statement cache\ + /// Prepares batch using CachingSession::prepare_batch if needed and then executes it + pub async fn batch( + &self, + batch: &Batch, + values: impl BatchValues, + ) -> Result { + let all_prepared: bool = batch + .statements + .iter() + .all(|stmt| matches!(stmt, BatchStatement::PreparedStatement(_))); + + if all_prepared { + self.session.batch(batch, &values).await + } else { + let prepared_batch: Batch = self.prepare_batch(batch).await?; + + self.session.batch(&prepared_batch, &values).await + } + } +} + +impl GenericCachingSession +where + S: BuildHasher + Clone, +{ /// Does the same thing as [`Session::execute_unpaged`] but uses the prepared statement cache pub async fn execute_unpaged( &self, @@ -126,7 +204,13 @@ where self.session.batch(&prepared_batch, &values).await } } +} +impl GenericCachingSession +where + S: BuildHasher + Clone, + DeserApi: DeserializationApiKind, +{ /// Prepares all statements within the batch and returns a new batch where every /// statement is prepared. /// Uses the prepared statements cache. @@ -212,7 +296,7 @@ where self.max_capacity } - pub fn get_session(&self) -> &LegacySession { + pub fn get_session(&self) -> &GenericSession { &self.session } } @@ -229,7 +313,7 @@ mod tests { use crate::{ batch::{Batch, BatchStatement}, prepared_statement::PreparedStatement, - CachingSession, LegacySession, + LegacyCachingSession, LegacySession, }; use futures::TryStreamExt; use std::collections::BTreeSet; @@ -273,8 +357,8 @@ mod tests { session } - async fn create_caching_session() -> CachingSession { - let session = CachingSession::from(new_for_test(true).await, 2); + async fn create_caching_session() -> LegacyCachingSession { + let session = LegacyCachingSession::from(new_for_test(true).await, 2); // Add a row, this makes it easier to check if the caching works combined with the regular execute fn on Session session @@ -385,7 +469,7 @@ mod tests { } async fn assert_test_batch_table_rows_contain( - sess: &CachingSession, + sess: &LegacyCachingSession, expected_rows: &[(i32, i32)], ) { let selected_rows: BTreeSet<(i32, i32)> = sess @@ -431,18 +515,18 @@ mod tests { } } - let _session: CachingSession = - CachingSession::from(new_for_test(true).await, 2); - let _session: CachingSession = - CachingSession::from(new_for_test(true).await, 2); - let _session: CachingSession = - CachingSession::with_hasher(new_for_test(true).await, 2, Default::default()); + let _session: LegacyCachingSession = + LegacyCachingSession::from(new_for_test(true).await, 2); + let _session: LegacyCachingSession = + LegacyCachingSession::from(new_for_test(true).await, 2); + let _session: LegacyCachingSession = + LegacyCachingSession::with_hasher(new_for_test(true).await, 2, Default::default()); } #[tokio::test] async fn test_batch() { setup_tracing(); - let session: CachingSession = create_caching_session().await; + let session: LegacyCachingSession = create_caching_session().await; session .execute_unpaged( @@ -565,7 +649,8 @@ mod tests { #[tokio::test] async fn test_parameters_caching() { setup_tracing(); - let session: CachingSession = CachingSession::from(new_for_test(true).await, 100); + let session: LegacyCachingSession = + LegacyCachingSession::from(new_for_test(true).await, 100); session .execute_unpaged("CREATE TABLE tbl (a int PRIMARY KEY, b int)", ()) @@ -618,7 +703,8 @@ mod tests { } // This test uses CDC which is not yet compatible with Scylla's tablets. - let session: CachingSession = CachingSession::from(new_for_test(false).await, 100); + let session: LegacyCachingSession = + LegacyCachingSession::from(new_for_test(false).await, 100); session .execute_unpaged( diff --git a/scylla/src/transport/session_test.rs b/scylla/src/transport/session_test.rs index 555800384..0bb685bd0 100644 --- a/scylla/src/transport/session_test.rs +++ b/scylla/src/transport/session_test.rs @@ -20,8 +20,8 @@ use crate::transport::topology::{ use crate::utils::test_utils::{ create_new_session_builder, supports_feature, unique_keyspace_name, }; -use crate::CachingSession; use crate::ExecutionProfile; +use crate::LegacyCachingSession; use crate::LegacyQueryResult; use crate::{LegacySession, SessionBuilder}; use assert_matches::assert_matches; @@ -2012,7 +2012,7 @@ async fn rename(session: &LegacySession, rename_str: &str) { .unwrap(); } -async fn rename_caching(session: &CachingSession, rename_str: &str) { +async fn rename_caching(session: &LegacyCachingSession, rename_str: &str) { session .execute_unpaged(format!("ALTER TABLE tab RENAME {}", rename_str), &()) .await @@ -2230,7 +2230,7 @@ async fn test_unprepared_reprepare_in_caching_session_execute() { session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap(); session.use_keyspace(ks, false).await.unwrap(); - let caching_session: CachingSession = CachingSession::from(session, 64); + let caching_session: LegacyCachingSession = LegacyCachingSession::from(session, 64); caching_session .execute_unpaged(