From 5a172546f14d2318002f665417e17294a5fff256 Mon Sep 17 00:00:00 2001 From: emcake <3726783+emcake@users.noreply.github.com> Date: Tue, 30 Jan 2024 14:15:34 +0000 Subject: [PATCH] fix: made generalize_filter less permissive, also added more cases --- crates/core/src/operations/merge/mod.rs | 143 ++++++++++++++++++++---- 1 file changed, 121 insertions(+), 22 deletions(-) diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index ffe2e78e38..c6cc72cc8a 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -52,7 +52,7 @@ use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; use datafusion_expr::expr::Placeholder; use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType}; use datafusion_expr::{ - BinaryExpr, Distinct, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Projection, + BinaryExpr, Distinct, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Operator, Projection, UserDefinedLogicalNode, UNNAMED_TABLE, }; use futures::future::BoxFuture; @@ -699,16 +699,34 @@ fn generalize_filter( target_name: &TableReference, placeholders: &mut HashMap, ) -> Option { - fn references_table(expr: &Expr, table: &TableReference) -> Option { - match expr { + #[derive(Debug)] + enum ReferenceTableCheck { + HasReference(String), + NoReference, + Unknown, + } + impl ReferenceTableCheck { + fn has_reference(&self) -> bool { + match self { + ReferenceTableCheck::HasReference(_) => true, + _ => false, + } + } + } + fn references_table(expr: &Expr, table: &TableReference) -> ReferenceTableCheck { + let res = match expr { Expr::Alias(alias) => references_table(&alias.expr, table), - Expr::Column(col) => col.relation.as_ref().and_then(|rel| { - if rel == table { - Some(col.name.to_owned()) - } else { - None - } - }), + Expr::Column(col) => col + .relation + .as_ref() + .map(|rel| { + if rel == table { + ReferenceTableCheck::HasReference(col.name.to_owned()) + } else { + ReferenceTableCheck::NoReference + } + }) + .unwrap_or(ReferenceTableCheck::NoReference), Expr::Negative(neg) => references_table(neg, table), Expr::Cast(cast) => references_table(&cast.expr, table), Expr::TryCast(try_cast) => references_table(&try_cast.expr, table), @@ -716,17 +734,22 @@ fn generalize_filter( if func.args.len() == 1 { references_table(&func.args[0], table) } else { - None + ReferenceTableCheck::Unknown } } - _ => None, - } + Expr::IsNull(inner) => references_table(&inner, table), + Expr::Literal(_) => ReferenceTableCheck::NoReference, + _ => ReferenceTableCheck::Unknown, + }; + res } match predicate { Expr::BinaryExpr(binary) => { - if references_table(&binary.right, source_name).is_some() { - if let Some(left_target) = references_table(&binary.left, target_name) { + if references_table(&binary.right, source_name).has_reference() { + if let ReferenceTableCheck::HasReference(left_target) = + references_table(&binary.left, target_name) + { if partition_columns.contains(&left_target) { let placeholder_name = format!("{left_target}_{}", placeholders.len()); @@ -747,8 +770,10 @@ fn generalize_filter( } return None; } - if references_table(&binary.left, source_name).is_some() { - if let Some(right_target) = references_table(&binary.right, target_name) { + if references_table(&binary.left, source_name).has_reference() { + if let ReferenceTableCheck::HasReference(right_target) = + references_table(&binary.right, target_name) + { if partition_columns.contains(&right_target) { let placeholder_name = format!("{right_target}_{}", placeholders.len()); @@ -785,19 +810,45 @@ fn generalize_filter( placeholders, ); - match (left, right) { + let res = match (left, right) { (None, None) => None, - (None, Some(r)) => Some(r), - (Some(l), None) => Some(l), + (None, Some(one_side)) | (Some(one_side), None) => { + // in the case of an AND clause, it's safe to generalize the filter down to just one side of the AND. + // this is because this filter will be more permissive than the actual predicate, so we know that + // we will catch all data that could be matched by the predicate. For OR this is not the case - we + // could potentially eliminate one side of the predicate and the filter would only match half the + // cases that would have satisfied the match predicate. + match binary.op { + Operator::And => Some(one_side), + Operator::Or => None, + _ => None, + } + } (Some(l), Some(r)) => Expr::BinaryExpr(BinaryExpr { left: l.into(), op: binary.op, right: r.into(), }) .into(), - } + }; + res } - other => Some(other), + other => match references_table(&other, source_name) { + ReferenceTableCheck::HasReference(col) => { + let placeholder_name = format!("{col}_{}", placeholders.len()); + + let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder { + id: placeholder_name.clone(), + data_type: None, + }); + + placeholders.insert(placeholder_name, other); + + Some(placeholder) + } + ReferenceTableCheck::NoReference => Some(other), + ReferenceTableCheck::Unknown => None, + }, } } @@ -1488,6 +1539,7 @@ mod tests { use datafusion_expr::Expr; use datafusion_expr::LogicalPlanBuilder; use datafusion_expr::Operator; + use itertools::Itertools; use serde_json::json; use std::collections::HashMap; use std::ops::Neg; @@ -2434,6 +2486,51 @@ mod tests { assert_eq!(generalized, expected_filter); } + #[tokio::test] + async fn test_generalize_filter_with_partitions_nulls() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let source_id = col(Column::new(source.clone().into(), "id")); + let target_id = col(Column::new(target.clone().into(), "id")); + + // source.id = target.id OR (source.id is null and target.id is null) + let parsed_filter = (source_id.clone().eq(target_id.clone())) + .or(source_id.clone().is_null().and(target_id.clone().is_null())); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + // id_1 = target.id OR (id_2 and target.id is null) + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(target_id.clone()) + .or(Expr::Placeholder(Placeholder { + id: "id_1".to_owned(), + data_type: None, + }) + .and(target_id.clone().is_null())); + + assert!(placeholders.len() == 2); + + let captured_expressions = placeholders.values().collect_vec(); + + assert!(captured_expressions.contains(&&source_id)); + assert!(captured_expressions.contains(&&source_id.is_null())); + + assert_eq!(generalized, expected_filter); + } + #[tokio::test] async fn test_generalize_filter_with_partitions_captures_expression() { // Check that when generalizing the filter, the placeholder map captures the expression needed to make the statement the same @@ -2478,6 +2575,7 @@ mod tests { let source = TableReference::parse_str("source"); let target = TableReference::parse_str("target"); + // source.id = target.id and target.id = 'C' let parsed_filter = col(Column::new(source.clone().into(), "id")) .eq(col(Column::new(target.clone().into(), "id"))) .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); @@ -2493,6 +2591,7 @@ mod tests { ) .unwrap(); + // id_0 = target.id and target.id = 'C' let expected_filter = Expr::Placeholder(Placeholder { id: "id_0".to_owned(), data_type: None,