Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/expr/src/table_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub trait TableSource: Sync + Send {
}

/// Tests whether the table provider can make use of any or all filter expressions
/// to optimise data retrieval.
/// to optimise data retrieval. Only non-volatile expressions are passed to this function.
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
Expand Down
183 changes: 133 additions & 50 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -988,32 +988,46 @@ impl OptimizerRule for PushDownFilter {
LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
LogicalPlan::TableScan(scan) => {
let filter_predicates = split_conjunction(&filter.predicate);
let results = scan

let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
filter_predicates
.into_iter()
.partition(|pred| pred.is_volatile());

// Check which non-volatile filters are supported by source
let supported_filters = scan
.source
.supports_filters_pushdown(filter_predicates.as_slice())?;
if filter_predicates.len() != results.len() {
.supports_filters_pushdown(non_volatile_filters.as_slice())?;
if non_volatile_filters.len() != supported_filters.len() {
return internal_err!(
"Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
results.len(),
filter_predicates.len());
supported_filters.len(),
non_volatile_filters.len());
}

let zip = filter_predicates.into_iter().zip(results);
// Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
let zip = non_volatile_filters.into_iter().zip(supported_filters);

let new_scan_filters = zip
.clone()
.filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
.map(|(pred, _)| pred);

// Add new scan filters
let new_scan_filters: Vec<Expr> = scan
.filters
.iter()
.chain(new_scan_filters)
.unique()
.cloned()
.collect();

// Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
let new_predicate: Vec<Expr> = zip
.filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
.map(|(pred, _)| pred.clone())
.map(|(pred, _)| pred)
.chain(volatile_filters)
.cloned()
.collect();

let new_scan = LogicalPlan::TableScan(TableScan {
Expand Down Expand Up @@ -2515,23 +2529,31 @@ mod tests {
}
}

fn table_scan_with_pushdown_provider(
fn table_scan_with_pushdown_provider_builder(
filter_support: TableProviderFilterPushDown,
) -> Result<LogicalPlan> {
filters: Vec<Expr>,
projection: Option<Vec<usize>>,
) -> Result<LogicalPlanBuilder> {
let test_provider = PushDownProvider { filter_support };

let table_scan = LogicalPlan::TableScan(TableScan {
table_name: "test".into(),
filters: vec![],
filters,
projected_schema: Arc::new(DFSchema::try_from(
(*test_provider.schema()).clone(),
)?),
projection: None,
projection,
source: Arc::new(test_provider),
fetch: None,
});

LogicalPlanBuilder::from(table_scan)
Ok(LogicalPlanBuilder::from(table_scan))
}

fn table_scan_with_pushdown_provider(
filter_support: TableProviderFilterPushDown,
) -> Result<LogicalPlan> {
table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
.filter(col("a").eq(lit(1i64)))?
.build()
}
Expand Down Expand Up @@ -2588,25 +2610,14 @@ mod tests {

#[test]
fn multi_combined_filter() -> Result<()> {
let test_provider = PushDownProvider {
filter_support: TableProviderFilterPushDown::Inexact,
};

let table_scan = LogicalPlan::TableScan(TableScan {
table_name: "test".into(),
filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
projected_schema: Arc::new(DFSchema::try_from(
(*test_provider.schema()).clone(),
)?),
projection: Some(vec![0]),
source: Arc::new(test_provider),
fetch: None,
});

let plan = LogicalPlanBuilder::from(table_scan)
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
.project(vec![col("a"), col("b")])?
.build()?;
let plan = table_scan_with_pushdown_provider_builder(
TableProviderFilterPushDown::Inexact,
vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
Some(vec![0]),
)?
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
.project(vec![col("a"), col("b")])?
.build()?;

let expected = "Projection: a, b\
\n Filter: a = Int64(10) AND b > Int64(11)\
Expand All @@ -2617,25 +2628,14 @@ mod tests {

#[test]
fn multi_combined_filter_exact() -> Result<()> {
let test_provider = PushDownProvider {
filter_support: TableProviderFilterPushDown::Exact,
};

let table_scan = LogicalPlan::TableScan(TableScan {
table_name: "test".into(),
filters: vec![],
projected_schema: Arc::new(DFSchema::try_from(
(*test_provider.schema()).clone(),
)?),
projection: Some(vec![0]),
source: Arc::new(test_provider),
fetch: None,
});

let plan = LogicalPlanBuilder::from(table_scan)
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
.project(vec![col("a"), col("b")])?
.build()?;
let plan = table_scan_with_pushdown_provider_builder(
TableProviderFilterPushDown::Exact,
vec![],
Some(vec![0]),
)?
.filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
.project(vec![col("a"), col("b")])?
.build()?;

let expected = r#"
Projection: a, b
Expand Down Expand Up @@ -3385,4 +3385,87 @@ Projection: a, b
\n TableScan: test2";
assert_optimized_plan_eq(plan, expected)
}

#[test]
fn test_push_down_volatile_table_scan() -> Result<()> {
// SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
let table_scan = test_table_scan()?;
let fun = ScalarUDF::new_from_impl(TestScalarUDF {
signature: Signature::exact(vec![], Volatility::Volatile),
});
let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.filter(expr.gt(lit(0.1)))?
.build()?;

let expected_before = "Filter: TestScalarUDF() > Float64(0.1)\
\n Projection: test.a, test.b\
\n TableScan: test";
assert_eq!(format!("{plan}"), expected_before);

let expected_after = "Projection: test.a, test.b\
\n Filter: TestScalarUDF() > Float64(0.1)\
\n TableScan: test";
assert_optimized_plan_eq(plan, expected_after)
}

#[test]
fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
// SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
let table_scan = test_table_scan()?;
let fun = ScalarUDF::new_from_impl(TestScalarUDF {
signature: Signature::exact(vec![], Volatility::Volatile),
});
let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.filter(
expr.gt(lit(0.1))
.and(col("t.a").gt(lit(5)))
.and(col("t.b").gt(lit(10))),
)?
.build()?;

let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\
\n Projection: test.a, test.b\
\n TableScan: test";
assert_eq!(format!("{plan}"), expected_before);

let expected_after = "Projection: test.a, test.b\
\n Filter: TestScalarUDF() > Float64(0.1)\
\n TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]";
assert_optimized_plan_eq(plan, expected_after)
}

#[test]
fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
// SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
let fun = ScalarUDF::new_from_impl(TestScalarUDF {
signature: Signature::exact(vec![], Volatility::Volatile),
});
let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
let plan = table_scan_with_pushdown_provider_builder(
TableProviderFilterPushDown::Unsupported,
vec![],
None,
)?
.project(vec![col("a"), col("b")])?
.filter(
expr.gt(lit(0.1))
.and(col("t.a").gt(lit(5)))
.and(col("t.b").gt(lit(10))),
)?
.build()?;

let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\
\n Projection: a, b\
\n TableScan: test";
assert_eq!(format!("{plan}"), expected_before);

let expected_after = "Projection: a, b\
\n Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)\
\n TableScan: test";
assert_optimized_plan_eq(plan, expected_after)
}
}