diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 08ca1a176e573..8e52ca1bf1af6 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2330,6 +2330,11 @@ impl DataFrame {
 
     /// Cache DataFrame as a memory table.
     ///
+    /// By default the plan is executed and its results are collected into a
+    /// [`MemTable`]. When a [`crate::execution::session_state::CacheFactory`]
+    /// is configured via [`SessionState`], the factory instead builds the plan
+    /// of the returned `DataFrame` and no `MemTable` is created.
+    ///
     /// ```
     /// # use datafusion::prelude::*;
     /// # use datafusion::error::Result;
@@ -2344,6 +2349,14 @@ impl DataFrame {
     /// # }
     /// ```
     pub async fn cache(self) -> Result<DataFrame> {
+        // A configured `CacheFactory` replaces the default strategy: wrap the
+        // plan it returns instead of materializing results into a `MemTable`.
+        if let Some(cache_factory) = self.session_state.cache_factory() {
+            let new_plan =
+                cache_factory.create(self.plan, self.session_state.as_ref())?;
+            return Ok(Self::new(*self.session_state, new_plan));
+        }
+
         let context = SessionContext::new_with_state((*self.session_state).clone());
         // The schema is consistent with the output
         let plan = self.clone().create_physical_plan().await?;
diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index d7a66db28ac47..8bc526fbfb8c6 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -185,6 +185,9 @@ pub struct SessionState {
     /// It will be invoked on `CREATE FUNCTION` statements.
     /// thus, changing dialect o PostgreSql is required
     function_factory: Option<Arc<dyn FunctionFactory>>,
+    /// Optional [`CacheFactory`] consulted by [`crate::dataframe::DataFrame::cache`]
+    /// to build a custom caching plan instead of the default in-memory table.
+    cache_factory: Option<Arc<dyn CacheFactory>>,
     /// Cache logical plans of prepared statements for later execution.
     /// Key is the prepared statement name.
     prepared_plans: HashMap<String, Arc<PreparedPlan>>,
@@ -206,6 +209,7 @@ impl Debug for SessionState {
             .field("table_options", &self.table_options)
             .field("table_factories", &self.table_factories)
             .field("function_factory", &self.function_factory)
+            .field("cache_factory", &self.cache_factory)
             .field("expr_planners", &self.expr_planners);
 
         #[cfg(feature = "sql")]
@@ -355,6 +359,16 @@ impl SessionState {
         self.function_factory.as_ref()
     }
 
+    /// Register a [`CacheFactory`] for custom caching strategy
+    pub fn set_cache_factory(&mut self, cache_factory: Arc<dyn CacheFactory>) {
+        self.cache_factory = Some(cache_factory);
+    }
+
+    /// Get the cache factory
+    pub fn cache_factory(&self) -> Option<&Arc<dyn CacheFactory>> {
+        self.cache_factory.as_ref()
+    }
+
     /// Get the table factories
     pub fn table_factories(&self) -> &HashMap<String, Arc<dyn TableProviderFactory>> {
         &self.table_factories
@@ -941,6 +955,7 @@ pub struct SessionStateBuilder {
     table_factories: Option<HashMap<String, Arc<dyn TableProviderFactory>>>,
     runtime_env: Option<Arc<RuntimeEnv>>,
     function_factory: Option<Arc<dyn FunctionFactory>>,
+    cache_factory: Option<Arc<dyn CacheFactory>>,
     // fields to support convenience functions
     analyzer_rules: Option<Vec<Arc<dyn AnalyzerRule + Send + Sync>>>,
     optimizer_rules: Option<Vec<Arc<dyn OptimizerRule + Send + Sync>>>,
@@ -978,6 +993,7 @@ impl SessionStateBuilder {
             table_factories: None,
             runtime_env: None,
             function_factory: None,
+            cache_factory: None,
             // fields to support convenience functions
             analyzer_rules: None,
             optimizer_rules: None,
@@ -1030,7 +1046,7 @@ impl SessionStateBuilder {
             table_factories: Some(existing.table_factories),
             runtime_env: Some(existing.runtime_env),
             function_factory: existing.function_factory,
-
+            cache_factory: existing.cache_factory,
             // fields to support convenience functions
             analyzer_rules: None,
             optimizer_rules: None,
@@ -1319,6 +1335,15 @@ impl SessionStateBuilder {
         self
     }
 
+    /// Set a [`CacheFactory`] for custom caching strategy
+    pub fn with_cache_factory(
+        mut self,
+        cache_factory: Option<Arc<dyn CacheFactory>>,
+    ) -> Self {
+        self.cache_factory = cache_factory;
+        self
+    }
+
     /// Register an `ObjectStore` to the [`RuntimeEnv`]. See [`RuntimeEnv::register_object_store`]
     /// for more details.
     ///
@@ -1382,6 +1407,7 @@ impl SessionStateBuilder {
             table_factories,
             runtime_env,
             function_factory,
+            cache_factory,
             analyzer_rules,
             optimizer_rules,
             physical_optimizer_rules,
@@ -1418,6 +1444,7 @@ impl SessionStateBuilder {
             table_factories: table_factories.unwrap_or_default(),
             runtime_env,
             function_factory,
+            cache_factory,
             prepared_plans: HashMap::new(),
         };
 
@@ -1621,6 +1648,11 @@ impl SessionStateBuilder {
         &mut self.function_factory
     }
 
+    /// Returns the cache factory
+    pub fn cache_factory(&mut self) -> &mut Option<Arc<dyn CacheFactory>> {
+        &mut self.cache_factory
+    }
+
     /// Returns the current analyzer_rules value
     pub fn analyzer_rules(
         &mut self,
@@ -1659,6 +1691,7 @@ impl Debug for SessionStateBuilder {
             .field("table_options", &self.table_options)
             .field("table_factories", &self.table_factories)
             .field("function_factory", &self.function_factory)
+            .field("cache_factory", &self.cache_factory)
             .field("expr_planners", &self.expr_planners);
         #[cfg(feature = "sql")]
         let ret = ret.field("type_planner", &self.type_planner);
@@ -2047,6 +2080,20 @@ pub(crate) struct PreparedPlan {
     pub(crate) plan: Arc<LogicalPlan>,
 }
 
+/// A [`CacheFactory`] can be registered via [`SessionState`]
+/// to create a custom logical plan for [`crate::dataframe::DataFrame::cache`].
+/// Additionally, a custom [`crate::physical_planner::ExtensionPlanner`]/[`QueryPlanner`]
+/// may need to be implemented to handle such plans.
+pub trait CacheFactory: Debug + Send + Sync {
+    /// Create the logical plan that `DataFrame::cache` returns for `plan`,
+    /// for example by wrapping `plan` in a custom caching node.
+    fn create(
+        &self,
+        plan: LogicalPlan,
+        session_state: &SessionState,
+    ) -> datafusion_common::Result<LogicalPlan>;
+}
+
 #[cfg(test)]
 mod tests {
     use super::{SessionContextProvider, SessionStateBuilder};
diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs
index 7149c5b0bd8ca..466ee38a426fd 100644
--- a/datafusion/core/src/test_util/mod.rs
+++ b/datafusion/core/src/test_util/mod.rs
@@ -25,6 +25,7 @@ pub mod csv;
 use futures::Stream;
 use std::any::Any;
 use std::collections::HashMap;
+use std::fmt::Formatter;
 use std::fs::File;
 use std::io::Write;
 use std::path::Path;
@@ -36,16 +37,20 @@ use crate::dataframe::DataFrame;
 use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable};
 use crate::datasource::{empty::EmptyTable, provider_as_source};
 use crate::error::Result;
+use crate::execution::session_state::CacheFactory;
 use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
 use crate::physical_plan::ExecutionPlan;
 use crate::prelude::{CsvReadOptions, SessionContext};
 
-use crate::execution::SendableRecordBatchStream;
+use crate::execution::{SendableRecordBatchStream, SessionState, SessionStateBuilder};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use datafusion_catalog::Session;
-use datafusion_common::TableReference;
-use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType};
+use datafusion_common::{DFSchemaRef, TableReference};
+use datafusion_expr::{
+    CreateExternalTable, Expr, LogicalPlan, SortExpr, TableType,
+    UserDefinedLogicalNodeCore,
+};
 use std::pin::Pin;
 
 use async_trait::async_trait;
@@ -282,3 +287,71 @@ impl RecordBatchStream for BoundedStream {
         self.record_batch.schema()
     }
 }
+
+/// Pass-through logical node used in tests: wraps `input` and reuses its schema.
+/// `TestCacheFactory` produces it so a test can see that a plan was wrapped.
+#[derive(Hash, Eq, PartialEq, PartialOrd, Debug)]
+struct CacheNode {
+    input: LogicalPlan,
+}
+
+impl UserDefinedLogicalNodeCore for CacheNode {
+    fn name(&self) -> &str {
+        "CacheNode"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.input]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        self.input.schema()
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        vec![]
+    }
+
+    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
+        write!(f, "CacheNode")
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        _exprs: Vec<Expr>,
+        mut inputs: Vec<LogicalPlan>,
+    ) -> Result<Self> {
+        assert_eq!(inputs.len(), 1, "input size inconsistent");
+        // The vector is owned: move the single input out instead of cloning it.
+        Ok(Self {
+            input: inputs.swap_remove(0),
+        })
+    }
+}
+
+/// `CacheFactory` for tests: wraps the plan in a `CacheNode` extension node.
+#[derive(Debug)]
+struct TestCacheFactory {}
+
+impl CacheFactory for TestCacheFactory {
+    fn create(
+        &self,
+        plan: LogicalPlan,
+        _session_state: &SessionState,
+    ) -> Result<LogicalPlan> {
+        Ok(LogicalPlan::Extension(datafusion_expr::Extension {
+            node: Arc::new(CacheNode { input: plan }),
+        }))
+    }
+}
+
+/// Create a test table registered to a session context with an associated cache factory
+pub async fn test_table_with_cache_factory() -> Result<DataFrame> {
+    let session_state = SessionStateBuilder::new()
+        .with_cache_factory(Some(Arc::new(TestCacheFactory {})))
+        .build();
+    let ctx = SessionContext::new_with_state(session_state);
+    let name = "aggregate_test_100";
+    register_aggregate_csv(&ctx, name).await?;
+    ctx.table(name).await
+}
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index b856a776c864a..fb6dc3bcba901 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -61,7 +61,7 @@ use datafusion::prelude::{
 };
 use datafusion::test_util::{
     parquet_test_data, populate_csv_partitions, register_aggregate_csv, test_table,
-    test_table_with_name,
+    test_table_with_cache_factory, test_table_with_name,
 };
 use datafusion_catalog::TableProvider;
 use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
@@ -2335,6 +2335,29 @@ async fn cache_test() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn cache_factory_test() -> Result<()> {
+    let df = test_table_with_cache_factory()
+        .await?
+        .select_columns(&["c2", "c3"])?
+        .limit(0, Some(1))?
+        .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
+
+    let cached_df = df.cache().await?;
+
+    assert_snapshot!(
+        cached_df.into_optimized_plan()?,
+        @r###"
+    CacheNode
+      Projection: aggregate_test_100.c2, aggregate_test_100.c3, CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS Int64) AS Int64) AS sum
+        Projection: aggregate_test_100.c2, aggregate_test_100.c3
+          Limit: skip=0, fetch=1
+            TableScan: aggregate_test_100, fetch=1
+    "###
+    );
+    Ok(())
+}
+
 #[tokio::test]
 async fn partition_aware_union() -> Result<()> {
     let left = test_table().await?.select_columns(&["c1", "c2"])?;