diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 55bff5849c5cb..6d62fbc38574d 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1045,6 +1045,22 @@ impl TreeNodeRewriter for Simplifier<'_> { ); } } + // A = L1 AND A != L2 --> A = L1 (when L1 != L2) + Expr::BinaryExpr(BinaryExpr { + left, + op: And, + right, + }) if is_eq_and_ne_with_different_literal(&left, &right) => { + Transformed::yes(*left) + } + // A != L2 AND A = L1 --> A = L1 (when L1 != L2) + Expr::BinaryExpr(BinaryExpr { + left, + op: And, + right, + }) if is_eq_and_ne_with_different_literal(&right, &left) => { + Transformed::yes(*right) + } // // Rules for Multiply @@ -2398,6 +2414,27 @@ mod tests { assert_eq!(simplify(expr_b), expected); } + #[test] + fn test_simplify_eq_and_neq_with_different_literals() { + // A = 1 AND A != 0 --> A = 1 (when 1 != 0) + let expr = col("c2").eq(lit(1)).and(col("c2").not_eq(lit(0))); + let expected = col("c2").eq(lit(1)); + assert_eq!(simplify(expr), expected); + + // A != 0 AND A = 1 --> A = 1 (when 1 != 0) + let expr = col("c2").not_eq(lit(0)).and(col("c2").eq(lit(1))); + let expected = col("c2").eq(lit(1)); + assert_eq!(simplify(expr), expected); + + // Should NOT simplify when literals are the same (A = 1 AND A != 1) + // This is a contradiction but handled by other rules + let expr = col("c2").eq(lit(1)).and(col("c2").not_eq(lit(1))); + // Should not be simplified by this rule (left unchanged or handled elsewhere) + let result = simplify(expr.clone()); + // The expression should not have been simplified + assert_eq!(result, expr); + } + #[test] fn test_simplify_multiply_by_one() { let expr_a = col("c2") * lit(1); diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 1f214e3d365c9..b0908b47602f7 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -290,6 +290,54 @@ pub fn is_lit(expr: &Expr) -> bool { matches!(expr, Expr::Literal(_, _)) } +/// Checks if `eq_expr` is `A = L1` and `ne_expr` is `A != L2` where L1 != L2. +/// This pattern can be simplified to just `A = L1` since if A equals L1 +/// and L1 is different from L2, then A is automatically not equal to L2. +pub fn is_eq_and_ne_with_different_literal(eq_expr: &Expr, ne_expr: &Expr) -> bool { + fn extract_var_and_literal(expr: &Expr) -> Option<(&Expr, &Expr)> { + match expr { + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) + | Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::NotEq, + right, + }) => match (left.as_ref(), right.as_ref()) { + (Expr::Literal(_, _), var) => Some((var, left)), + (var, Expr::Literal(_, _)) => Some((var, right)), + _ => None, + }, + _ => None, + } + } + match (eq_expr, ne_expr) { + ( + Expr::BinaryExpr(BinaryExpr { + op: Operator::Eq, .. + }), + Expr::BinaryExpr(BinaryExpr { + op: Operator::NotEq, + .. + }), + ) => { + // Check if both compare the same expression against different literals + if let (Some((var1, lit1)), Some((var2, lit2))) = ( + extract_var_and_literal(eq_expr), + extract_var_and_literal(ne_expr), + ) && var1 == var2 + && lit1 != lit2 + { + return true; + } + false + } + _ => false, + } +} + /// negate a Not clause /// input is the clause to be negated.(args of Not clause) /// For BinaryExpr, use the negation of op instead. diff --git a/datafusion/sqllogictest/test_files/join.slt.part b/datafusion/sqllogictest/test_files/join.slt.part index 5d111374ac8cf..c0a838c97d552 100644 --- a/datafusion/sqllogictest/test_files/join.slt.part +++ b/datafusion/sqllogictest/test_files/join.slt.part @@ -973,19 +973,19 @@ ON e.emp_id = d.emp_id WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); ---- logical_plan -01)Filter: d.dept_name != Utf8View("Engineering") AND e.name = Utf8View("Alice") OR e.name != Utf8View("Alice") AND e.name = Utf8View("Carol") +01)Filter: d.dept_name != Utf8View("Engineering") AND e.name = Utf8View("Alice") OR e.name = Utf8View("Carol") 02)--Projection: e.emp_id, e.name, d.dept_name 03)----Left Join: e.emp_id = d.emp_id 04)------SubqueryAlias: e -05)--------Filter: employees.name = Utf8View("Alice") OR employees.name != Utf8View("Alice") AND employees.name = Utf8View("Carol") +05)--------Filter: employees.name = Utf8View("Alice") OR employees.name = Utf8View("Carol") 06)----------TableScan: employees projection=[emp_id, name] 07)------SubqueryAlias: d 08)--------TableScan: department projection=[emp_id, dept_name] physical_plan -01)FilterExec: dept_name@2 != Engineering AND name@1 = Alice OR name@1 != Alice AND name@1 = Carol +01)FilterExec: dept_name@2 != Engineering AND name@1 = Alice OR name@1 = Carol 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(emp_id@0, emp_id@0)], projection=[emp_id@0, name@1, dept_name@3] -04)------FilterExec: name@1 = Alice OR name@1 != Alice AND name@1 = Carol +04)------FilterExec: name@1 = Alice OR name@1 = Carol 05)--------DataSourceExec: partitions=1, partition_sizes=[1] 06)------DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/simplify_expr.slt b/datafusion/sqllogictest/test_files/simplify_expr.slt index d8c25ab25e8ea..99fc9900ef619 100644 --- a/datafusion/sqllogictest/test_files/simplify_expr.slt +++ b/datafusion/sqllogictest/test_files/simplify_expr.slt @@ -113,3 +113,21 @@ logical_plan physical_plan 01)ProjectionExec: expr=[[{x:100}] as a] 02)--PlaceholderRowExec + +# Simplify expr = L1 AND expr != L2 to expr = L1 when L1 != L2 +query TT +EXPLAIN SELECT + v = 1 AND v != 0 as opt1, + v = 2 AND v != 2 as noopt1, + v != 3 AND v = 4 as opt2, + v != 5 AND v = 5 as noopt2 +FROM (VALUES (0), (1), (2)) t(v) +---- +logical_plan +01)Projection: t.v = Int64(1) AS opt1, t.v = Int64(2) AND t.v != Int64(2) AS noopt1, t.v = Int64(4) AS opt2, t.v != Int64(5) AND t.v = Int64(5) AS noopt2 +02)--SubqueryAlias: t +03)----Projection: column1 AS v +04)------Values: (Int64(0)), (Int64(1)), (Int64(2)) +physical_plan +01)ProjectionExec: expr=[column1@0 = 1 as opt1, column1@0 = 2 AND column1@0 != 2 as noopt1, column1@0 = 4 as opt2, column1@0 != 5 AND column1@0 = 5 as noopt2] +02)--DataSourceExec: partitions=1, partition_sizes=[1]