diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index bdb602d48dee5..e9a677de50c13 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -99,7 +99,7 @@ pub trait TableSource: Sync + Send { } /// Tests whether the table provider can make use of any or all filter expressions - /// to optimise data retrieval. + /// to optimise data retrieval. Only non-volatile expressions are passed to this function. fn supports_filters_pushdown( &self, filters: &[&Expr], diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 23cd46803c78d..195dc06578b2b 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -988,22 +988,32 @@ impl OptimizerRule for PushDownFilter { LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); - let results = scan + + let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = + filter_predicates + .into_iter() + .partition(|pred| pred.is_volatile()); + + // Check which non-volatile filters are supported by source + let supported_filters = scan .source - .supports_filters_pushdown(filter_predicates.as_slice())?; - if filter_predicates.len() != results.len() { + .supports_filters_pushdown(non_volatile_filters.as_slice())?; + if non_volatile_filters.len() != supported_filters.len() { return internal_err!( "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}", - results.len(), - filter_predicates.len()); + supported_filters.len(), + non_volatile_filters.len()); } - let zip = filter_predicates.into_iter().zip(results); + // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type + let zip = non_volatile_filters.into_iter().zip(supported_filters); let 
new_scan_filters = zip .clone() .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported) .map(|(pred, _)| pred); + + // Add new scan filters let new_scan_filters: Vec<Expr> = scan .filters .iter() @@ -1011,9 +1021,13 @@ impl OptimizerRule for PushDownFilter { .unique() .cloned() .collect(); + + // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters let new_predicate: Vec<Expr> = zip .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact) - .map(|(pred, _)| pred.clone()) + .map(|(pred, _)| pred) + .chain(volatile_filters) + .cloned() .collect(); let new_scan = LogicalPlan::TableScan(TableScan { @@ -2515,23 +2529,31 @@ mod tests { } } - fn table_scan_with_pushdown_provider( + fn table_scan_with_pushdown_provider_builder( filter_support: TableProviderFilterPushDown, - ) -> Result<LogicalPlan> { + filters: Vec<Expr>, + projection: Option<Vec<usize>>, + ) -> Result<LogicalPlanBuilder> { let test_provider = PushDownProvider { filter_support }; let table_scan = LogicalPlan::TableScan(TableScan { table_name: "test".into(), - filters: vec![], + filters, projected_schema: Arc::new(DFSchema::try_from( (*test_provider.schema()).clone(), )?), - projection: None, + projection, source: Arc::new(test_provider), fetch: None, }); - LogicalPlanBuilder::from(table_scan) + Ok(LogicalPlanBuilder::from(table_scan)) + } + + fn table_scan_with_pushdown_provider( + filter_support: TableProviderFilterPushDown, + ) -> Result<LogicalPlan> { + table_scan_with_pushdown_provider_builder(filter_support, vec![], None)? .filter(col("a").eq(lit(1i64)))? 
.build() } @@ -2588,25 +2610,14 @@ mod tests { #[test] fn multi_combined_filter() -> Result<()> { - let test_provider = PushDownProvider { - filter_support: TableProviderFilterPushDown::Inexact, - }; - - let table_scan = LogicalPlan::TableScan(TableScan { - table_name: "test".into(), - filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))], - projected_schema: Arc::new(DFSchema::try_from( - (*test_provider.schema()).clone(), - )?), - projection: Some(vec![0]), - source: Arc::new(test_provider), - fetch: None, - }); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? - .project(vec![col("a"), col("b")])? - .build()?; + let plan = table_scan_with_pushdown_provider_builder( + TableProviderFilterPushDown::Inexact, + vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))], + Some(vec![0]), + )? + .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? + .project(vec![col("a"), col("b")])? + .build()?; let expected = "Projection: a, b\ \n Filter: a = Int64(10) AND b > Int64(11)\ @@ -2617,25 +2628,14 @@ mod tests { #[test] fn multi_combined_filter_exact() -> Result<()> { - let test_provider = PushDownProvider { - filter_support: TableProviderFilterPushDown::Exact, - }; - - let table_scan = LogicalPlan::TableScan(TableScan { - table_name: "test".into(), - filters: vec![], - projected_schema: Arc::new(DFSchema::try_from( - (*test_provider.schema()).clone(), - )?), - projection: Some(vec![0]), - source: Arc::new(test_provider), - fetch: None, - }); - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? - .project(vec![col("a"), col("b")])? - .build()?; + let plan = table_scan_with_pushdown_provider_builder( + TableProviderFilterPushDown::Exact, + vec![], + Some(vec![0]), + )? + .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))? + .project(vec![col("a"), col("b")])? 
+ .build()?; let expected = r#" Projection: a, b @@ -3385,4 +3385,87 @@ Projection: a, b \n TableScan: test2"; assert_optimized_plan_eq(plan, expected) } + + #[test] + fn test_push_down_volatile_table_scan() -> Result<()> { + // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1; + let table_scan = test_table_scan()?; + let fun = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::exact(vec![], Volatility::Volatile), + }); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .filter(expr.gt(lit(0.1)))? + .build()?; + + let expected_before = "Filter: TestScalarUDF() > Float64(0.1)\ + \n Projection: test.a, test.b\ + \n TableScan: test"; + assert_eq!(format!("{plan}"), expected_before); + + let expected_after = "Projection: test.a, test.b\ + \n Filter: TestScalarUDF() > Float64(0.1)\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected_after) + } + + #[test] + fn test_push_down_volatile_mixed_table_scan() -> Result<()> { + // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10; + let table_scan = test_table_scan()?; + let fun = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::exact(vec![], Volatility::Volatile), + }); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .filter( + expr.gt(lit(0.1)) + .and(col("t.a").gt(lit(5))) + .and(col("t.b").gt(lit(10))), + )? 
+ .build()?; + + let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\ + \n Projection: test.a, test.b\ + \n TableScan: test"; + assert_eq!(format!("{plan}"), expected_before); + + let expected_after = "Projection: test.a, test.b\ + \n Filter: TestScalarUDF() > Float64(0.1)\ + \n TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]"; + assert_optimized_plan_eq(plan, expected_after) + } + + #[test] + fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> { + // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10; + let fun = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::exact(vec![], Volatility::Volatile), + }); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); + let plan = table_scan_with_pushdown_provider_builder( + TableProviderFilterPushDown::Unsupported, + vec![], + None, + )? + .project(vec![col("a"), col("b")])? + .filter( + expr.gt(lit(0.1)) + .and(col("t.a").gt(lit(5))) + .and(col("t.b").gt(lit(10))), + )? + .build()?; + + let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\ + \n Projection: a, b\ + \n TableScan: test"; + assert_eq!(format!("{plan}"), expected_before); + + let expected_after = "Projection: a, b\ + \n Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)\ + \n TableScan: test"; + assert_optimized_plan_eq(plan, expected_after) + } }