
Commit e2faf1a

wangyum, sunchao, and cloud-fan authored and committed
[CARMEL-4286][CARMEL-4380] Backport SPARK-33910 Simplify/Optimize conditional expressions (#291)
* [SPARK-32721][SQL] Simplify if clauses with null and boolean

### What changes were proposed in this pull request?
The following if clause:
```sql
if(p, null, false)
```
can be simplified to:
```sql
and(p, null)
```
Similarly, the clause:
```sql
if(p, null, true)
```
can be simplified to:
```sql
or(not(p), null)
```
iff the predicate `p` is non-nullable, i.e., it can be evaluated to either true or false, but not null.

### Why are the changes needed?
Converting if to or/and clauses can better push filters down.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit tests.

Closes #29567 from sunchao/SPARK-32721.
Authored-by: Chao Sun <sunchao@apache.org>
Signed-off-by: DB Tsai <d_tsai@apple.com>
(cherry picked from commit 1453a09)

* [SPARK-32721][SQL][FOLLOWUP] Simplify if clauses with null and boolean

### What changes were proposed in this pull request?
This is a follow-up on SPARK-32721 and PR #29567. In the previous PR we missed two more cases that can be optimized:
```
if(p, false, null) ==> and(not(p), null)
if(p, true, null)  ==> or(p, null)
```

### Why are the changes needed?
By transforming if to boolean conjunctions or disjunctions, we can enable more filter pushdown to data sources.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added unit tests.

Closes #29603 from sunchao/SPARK-32721-2.
Authored-by: Chao Sun <sunchao@apache.org>
Signed-off-by: DB Tsai <d_tsai@apple.com>
(cherry picked from commit 94d313b)

* [SPARK-33798][SQL] Add new rule to push down the foldable expressions through CaseWhen/If

### What changes were proposed in this pull request?
This PR adds a new rule (`PushFoldableIntoBranches`) to push down foldable expressions through `CaseWhen`/`If`. This is a real case from production:
```sql
create table t1 using parquet as select * from range(100);
create table t2 using parquet as select * from range(200);
create temp view v1 as
select 'a' as event_type, * from t1
union all
select CASE WHEN id = 1 THEN 'b' WHEN id = 3 THEN 'c' end as event_type, * from t2;

explain select * from v1 where event_type = 'a';
```

Before this PR:
```
== Physical Plan ==
Union
:- *(1) Project [a AS event_type#30533, id#30535L]
:  +- *(1) ColumnarToRow
:     +- FileScan parquet default.t1[id#30535L] Batched: true, DataFilters: [], Format: Parquet
+- *(2) Project [CASE WHEN (id#30536L = 1) THEN b WHEN (id#30536L = 3) THEN c END AS event_type#30534, id#30536L]
   +- *(2) Filter (CASE WHEN (id#30536L = 1) THEN b WHEN (id#30536L = 3) THEN c END = a)
      +- *(2) ColumnarToRow
         +- FileScan parquet default.t2[id#30536L] Batched: true, DataFilters: [(CASE WHEN (id#30536L = 1) THEN b WHEN (id#30536L = 3) THEN c END = a)], Format: Parquet
```

After this PR:
```
== Physical Plan ==
*(1) Project [a AS event_type#8, id#4L]
+- *(1) ColumnarToRow
   +- FileScan parquet default.t1[id#4L] Batched: true, DataFilters: [], Format: Parquet
```

### Why are the changes needed?
Improve query performance.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit test.

Closes #30790 from wangyum/SPARK-33798.
Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 06b1bbb)

* [SPARK-33845][SQL] Remove unnecessary if when trueValue and falseValue are foldable boolean types

### What changes were proposed in this pull request?
Improve `SimplifyConditionals`:
- Simplify `If(cond, TrueLiteral, FalseLiteral)` to `cond`.
- Simplify `If(cond, FalseLiteral, TrueLiteral)` to `Not(cond)`.

The use case is:
```sql
create table t1 using parquet as select id from range(10);
select if (id > 2, false, true) from t1;
```

Before this PR:
```
== Physical Plan ==
*(1) Project [if ((id#1L > 2)) false else true AS (IF((id > CAST(2 AS BIGINT)), false, true))#2]
+- *(1) ColumnarToRow
   +- FileScan parquet default.t1[id#1L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/spark-warehouse/org.apache.spark.sql.DataF..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint>
```

After this PR:
```
== Physical Plan ==
*(1) Project [(id#1L <= 2) AS (IF((id > CAST(2 AS BIGINT)), false, true))#2]
+- *(1) ColumnarToRow
   +- FileScan parquet default.t1[id#1L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/spark-warehouse/org.apache.spark.sql.DataF..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint>
```

### Why are the changes needed?
Improve query performance.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit test.

Closes #30849 from wangyum/SPARK-33798-2.
Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>

* [SPARK-33848][SQL] Push the UnaryExpression into (if / case) branches

### What changes were proposed in this pull request?
This PR pushes the `UnaryExpression` into (if / case) branches. The use case is:
```sql
create table t1 using parquet as select id from range(10);
explain select id from t1 where (CASE WHEN id = 1 THEN '1' WHEN id = 3 THEN '2' end) > 3;
```

Before this PR:
```
== Physical Plan ==
*(1) Filter (cast(CASE WHEN (id#1L = 1) THEN 1 WHEN (id#1L = 3) THEN 2 END as int) > 3)
+- *(1) ColumnarToRow
   +- FileScan parquet default.t1[id#1L] Batched: true, DataFilters: [(cast(CASE WHEN (id#1L = 1) THEN 1 WHEN (id#1L = 3) THEN 2 END as int) > 3)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/spark-warehouse/org.apache.spark.sql.DataF..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint>
```

After this PR:
```
== Physical Plan ==
LocalTableScan <empty>, [id#1L]
```

This change can also improve this case: https://github.com/apache/spark/blob/a78d6ce376edf2a8836e01f47b9dff5371058d4c/sql/core/src/test/resources/tpcds/q62.sql#L5-L22

### Why are the changes needed?
Improve query performance.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit test.

Closes #30853 from wangyum/SPARK-33848.
Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
(cherry picked from commit 1c77605)

* [SPARK-33847][SQL] Simplify CaseWhen if elseValue is None

### What changes were proposed in this pull request?
1. Enhance `ReplaceNullWithFalseInPredicate` to replace a None elseValue inside `CaseWhen` with `FalseLiteral` if all branches are `FalseLiteral`. The use case is:
```sql
create table t1 using parquet as select id from range(10);
explain select id from t1 where (CASE WHEN id = 1 THEN 'a' WHEN id = 3 THEN 'b' end) = 'c';
```

Before this PR:
```
== Physical Plan ==
*(1) Filter CASE WHEN (id#1L = 1) THEN false WHEN (id#1L = 3) THEN false END
+- *(1) ColumnarToRow
   +- FileScan parquet default.t1[id#1L] Batched: true, DataFilters: [CASE WHEN (id#1L = 1) THEN false WHEN (id#1L = 3) THEN false END], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/spark-warehouse/org.apache.spark.sql.DataF..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint>
```

After this PR:
```
== Physical Plan ==
LocalTableScan <empty>, [id#1L]
```

2. Enhance `SimplifyConditionals` if elseValue is None and all outputs are null.

### Why are the changes needed?
Improve query performance.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit test.

Closes #30852 from wangyum/SPARK-33847.
Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 7ffcfcf)

* [SPARK-33861][SQL] Simplify conditional in predicate

### What changes were proposed in this pull request?
This PR simplifies conditionals in predicates; after this change we can push the filter down to the data source:

Expression | After simplify
-- | --
IF(cond, trueVal, false) | AND(cond, trueVal)
IF(cond, trueVal, true) | OR(NOT(cond), trueVal)
IF(cond, false, falseVal) | AND(NOT(cond), elseVal)
IF(cond, true, falseVal) | OR(cond, elseVal)
CASE WHEN cond THEN trueVal ELSE false END | AND(cond, trueVal)
CASE WHEN cond THEN trueVal END | AND(cond, trueVal)
CASE WHEN cond THEN trueVal ELSE null END | AND(cond, trueVal)
CASE WHEN cond THEN trueVal ELSE true END | OR(NOT(cond), trueVal)
CASE WHEN cond THEN false ELSE elseVal END | AND(NOT(cond), elseVal)
CASE WHEN cond THEN false END | false
CASE WHEN cond THEN true ELSE elseVal END | OR(cond, elseVal)
CASE WHEN cond THEN true END | cond

### Why are the changes needed?
Improve query performance.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit test.

Closes #30865 from wangyum/SPARK-33861.
Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 32d4a2b)

* Fix

* [SPARK-33845][SQL][FOLLOWUP] fix SimplifyConditionals

### What changes were proposed in this pull request?
This is a follow-up of #30849, to fix a correctness issue caused by null value handling.

### Why are the changes needed?
Fix a correctness issue. `If(null, true, false)` should return false, not true.

### Does this PR introduce _any_ user-facing change?
Yes, but the bug only exists in the master branch.

### How was this patch tested?
Updated tests.

Closes #30953 from cloud-fan/bug.
Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
(cherry picked from commit c2eac1d)

* Fix

* [SPARK-33884][SQL] Simplify CaseWhen clauses with (true and false) and (false and true)

### What changes were proposed in this pull request?
This PR simplifies `CaseWhen` clauses with (true and false) and (false and true):

Expression | cond.nullable | After simplify
-- | -- | --
case when cond then true else false end | true | cond <=> true
case when cond then true else false end | false | cond
case when cond then false else true end | true | !(cond <=> true)
case when cond then false else true end | false | !cond

### Why are the changes needed?
Improve query performance.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit test.

Closes #30898 from wangyum/SPARK-33884.
Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit f7bdea3)

* [SPARK-33848][SQL][FOLLOWUP] Introduce allowList for push into (if / case) branches

### What changes were proposed in this pull request?
Introduce an allowList for pushing into (if / case) branches, to fix a potential bug.

### Why are the changes needed?
Fix a potential bug.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

Closes #30955 from wangyum/SPARK-33848-2.
Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 872107f)

* Fix

* [SPARK-33847][SQL][FOLLOWUP] Remove the CaseWhen should consider deterministic

### What changes were proposed in this pull request?
This PR fixes the removal of `CaseWhen` when elseValue is empty and all other outputs are null: the rewrite must take determinism into account.

### Why are the changes needed?
Fix a bug.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit test.

Closes #30960 from wangyum/SPARK-33847-2.
Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit c425024)

Co-authored-by: Chao Sun <sunchao@apache.org>
Co-authored-by: Wenchen Fan <wenchen@databricks.com>
1 parent 7811c34 commit e2faf1a
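
Taken together, these rules rewrite conditionals that appear in predicates into plain boolean connectives and fold constants into branches, which is what unlocks filter pushdown. Below is a minimal sketch of how one might observe the effect locally, assuming a Spark build that includes this backport; the view name and queries are illustrative and not part of the commit:

```scala
import org.apache.spark.sql.SparkSession

object ConditionalSimplificationDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("conditional-simplification-demo")
      .getOrCreate()

    spark.range(100).createOrReplaceTempView("t1")

    // A CASE WHEN with no ELSE used inside a predicate: with ReplaceNullWithFalseInPredicate
    // and SimplifyConditionalsInPredicate the whole filter should collapse (here to false)
    // instead of remaining an un-pushable conditional.
    spark.sql(
      "SELECT id FROM t1 WHERE (CASE WHEN id = 1 THEN 'a' WHEN id = 3 THEN 'b' END) = 'c'"
    ).explain(true)

    // A foldable literal compared against an IF: with PushFoldableIntoBranches the comparison
    // is pushed into both branches and then constant-folded.
    spark.sql("SELECT id FROM t1 WHERE IF(id > 2, 'x', 'y') = 'x'").explain(true)

    spark.stop()
  }
}
```

The optimized and physical plans printed by `explain(true)` should show the simplified predicates described in the commit message above.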

File tree

11 files changed, +879 -40 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 11 additions & 0 deletions
```diff
@@ -521,6 +521,12 @@ abstract class UnaryExpression extends Expression {
   }
 }
 
+
+object UnaryExpression {
+  def unapply(e: UnaryExpression): Option[Expression] = Some(e.child)
+}
+
+
 /**
  * An expression with two inputs and one output. The output is by default evaluated to null
  * if any input is evaluated to null.
@@ -621,6 +627,11 @@ abstract class BinaryExpression extends Expression {
 }
 
 
+object BinaryExpression {
+  def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right))
+}
+
+
 /**
  * A [[BinaryExpression]] that is an operator, with two properties:
  *
```
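
These companion objects only expose `unapply`, so optimizer rules can match on "any unary/binary expression" and look inside its children. A rough standalone sketch of that usage follows (hypothetical demo code, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.BooleanType

object ExtractorDemo {
  def main(args: Array[String]): Unit = {
    val cond = AttributeReference("flag", BooleanType)()

    // UnaryExpression.unapply exposes the single child, so a rule can match
    // "a unary expression wrapping an If" without enumerating concrete types.
    val expr: Expression = IsNull(If(cond, Literal(1), Literal(2)))
    expr match {
      case u @ UnaryExpression(If(_, trueValue, falseValue)) =>
        println(s"${u.prettyName} wraps an If with branches $trueValue / $falseValue")
      case _ =>
        println("no match")
    }

    // BinaryExpression.unapply exposes (left, right) in the same way.
    val cmp: Expression = GreaterThan(If(cond, Literal(1), Literal(2)), Literal(3))
    cmp match {
      case BinaryExpression(If(_, t, f), right) =>
        println(s"comparison of an If against $right, branches $t / $f")
      case _ =>
        println("no match")
    }
  }
}
```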

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -96,9 +96,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
       LikeSimplification,
       BooleanSimplification,
       SimplifyConditionals,
+      PushFoldableIntoBranches,
       RemoveDispensableExpressions,
       SimplifyBinaryComparison,
       ReplaceNullWithFalseInPredicate,
+      SimplifyConditionalsInPredicate,
       PruneFilters,
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala

Lines changed: 9 additions & 5 deletions
```diff
@@ -17,12 +17,10 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If}
-import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or}
-import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, LambdaFunction, Literal, MapFilter, Or}
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.BooleanType
 import org.apache.spark.util.Utils
 
@@ -55,6 +53,12 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
     case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
     case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
     case p: LogicalPlan => p transformExpressions {
+      // For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no
+      // difference, as `null <=> true` and `false <=> true` both return false.
+      case EqualNullSafe(left, TrueLiteral) =>
+        EqualNullSafe(replaceNullWithFalse(left), TrueLiteral)
+      case EqualNullSafe(TrueLiteral, right) =>
+        EqualNullSafe(TrueLiteral, replaceNullWithFalse(right))
       case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
       case cw @ CaseWhen(branches, _) =>
         val newBranches = branches.map { case (cond, value) =>
@@ -92,7 +96,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
       val newBranches = cw.branches.map { case (cond, value) =>
        replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
      }
-      val newElseValue = cw.elseValue.map(replaceNullWithFalse)
+      val newElseValue = cw.elseValue.map(replaceNullWithFalse).getOrElse(FalseLiteral)
      CaseWhen(newBranches, newElseValue)
     case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
      If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
```
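
The two behaviors this diff relies on can be checked quickly from a `spark-shell` session (assuming `spark` is the predefined session; the view name is made up):

```scala
// `null <=> true` and `false <=> true` both evaluate to false, which is why the rule may
// safely replace null with false on the non-literal side of a `<=> true` comparison.
spark.sql("SELECT null <=> true AS null_side, false <=> true AS false_side").show()

// A CASE WHEN with no ELSE branch used as a search condition: treating the missing ELSE as
// false rather than null is safe here, because a null search condition rejects the row anyway.
spark.range(10).createOrReplaceTempView("t")
spark.sql("SELECT id FROM t WHERE (CASE WHEN id = 1 THEN true END)").explain(true)
```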
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicate.scala

Lines changed: 80 additions & 0 deletions

```diff
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, If, Literal, Not, Or}
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.BooleanType
+
+/**
+ * A rule that converts conditional expressions to predicate expressions, if possible, in the
+ * search condition of the WHERE/HAVING/ON(JOIN) clauses, which contain an implicit Boolean operator
+ * "(search condition) = TRUE". After this converting, we can potentially push the filter down to
+ * the data source.
+ *
+ * Supported cases are:
+ * - IF(cond, trueVal, false)                   => AND(cond, trueVal)
+ * - IF(cond, trueVal, true)                    => OR(NOT(cond), trueVal)
+ * - IF(cond, false, falseVal)                  => AND(NOT(cond), elseVal)
+ * - IF(cond, true, falseVal)                   => OR(cond, elseVal)
+ * - CASE WHEN cond THEN trueVal ELSE false END => AND(cond, trueVal)
+ * - CASE WHEN cond THEN trueVal END            => AND(cond, trueVal)
+ * - CASE WHEN cond THEN trueVal ELSE null END  => AND(cond, trueVal)
+ * - CASE WHEN cond THEN trueVal ELSE true END  => OR(NOT(cond), trueVal)
+ * - CASE WHEN cond THEN false ELSE elseVal END => AND(NOT(cond), elseVal)
+ * - CASE WHEN cond THEN false END              => false
+ * - CASE WHEN cond THEN true ELSE elseVal END  => OR(cond, elseVal)
+ * - CASE WHEN cond THEN true END               => cond
+ */
+object SimplifyConditionalsInPredicate extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case f @ Filter(cond, _) => f.copy(condition = simplifyConditional(cond))
+    case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(simplifyConditional(cond)))
+  }
+
+  private def simplifyConditional(e: Expression): Expression = e match {
+    case And(left, right) => And(simplifyConditional(left), simplifyConditional(right))
+    case Or(left, right) => Or(simplifyConditional(left), simplifyConditional(right))
+    case If(cond, trueValue, FalseLiteral) => And(cond, trueValue)
+    case If(cond, trueValue, TrueLiteral) => Or(Not(cond), trueValue)
+    case If(cond, FalseLiteral, falseValue) => And(Not(cond), falseValue)
+    case If(cond, TrueLiteral, falseValue) => Or(cond, falseValue)
+    case CaseWhen(Seq((cond, trueValue)),
+        Some(FalseLiteral) | Some(Literal(null, BooleanType)) | None) =>
+      And(cond, trueValue)
+    case CaseWhen(Seq((cond, trueValue)), Some(TrueLiteral)) =>
+      Or(Not(cond), trueValue)
+    case CaseWhen(Seq((_, FalseLiteral)), Some(FalseLiteral) | None) =>
+      FalseLiteral
+    case CaseWhen(Seq((cond, FalseLiteral)), Some(elseValue)) =>
+      And(Not(cond), elseValue)
+    case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral) | None) =>
+      cond
+    case CaseWhen(Seq((cond, TrueLiteral)), Some(elseValue)) =>
+      Or(cond, elseValue)
+    case e if e.dataType == BooleanType => e
+    case e =>
+      assert(e.dataType != BooleanType,
+        "Expected a Boolean type expression in SimplifyConditionalsInPredicate, " +
+          s"but got the type `${e.dataType.catalogString}` in `${e.sql}`.")
+      e
+  }
+}
```
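
For reference, here is a small sketch in the style of the Catalyst optimizer test suites showing the rule applied directly to a logical plan; the relation and column names are made up and this snippet is not part of the commit:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.If
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
import org.apache.spark.sql.catalyst.optimizer.SimplifyConditionalsInPredicate
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object SimplifyConditionalsInPredicateDemo {
  def main(args: Array[String]): Unit = {
    val relation = LocalRelation('a.int, 'b.boolean)

    // IF(a > 10, b, false) used as a filter condition should be rewritten to (a > 10) AND b,
    // per the "IF(cond, trueVal, false) => AND(cond, trueVal)" case documented above.
    val query = relation.where(If('a > 10, 'b, FalseLiteral)).analyze
    println(SimplifyConditionalsInPredicate(query).treeString)
  }
}
```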

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 99 additions & 3 deletions
```diff
@@ -21,7 +21,7 @@ import scala.collection.immutable.HashSet
 import scala.collection.mutable.{ArrayBuffer, Stack}
 
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, _}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
@@ -461,8 +461,21 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
       case If(TrueLiteral, trueValue, _) => trueValue
       case If(FalseLiteral, _, falseValue) => falseValue
       case If(Literal(null, _), _, falseValue) => falseValue
+      case If(cond, TrueLiteral, FalseLiteral) =>
+        if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond
+      case If(cond, FalseLiteral, TrueLiteral) =>
+        if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond)
       case If(cond, trueValue, falseValue)
         if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
+      case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l)
+      case If(cond, l @ Literal(null, _), TrueLiteral) if !cond.nullable => Or(Not(cond), l)
+      case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l)
+      case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l)
+
+      case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral)) =>
+        if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond
+      case CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)) =>
+        if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond)
 
       case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
         // If there are branches that are always false, remove them.
@@ -488,8 +501,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
         val (h, t) = branches.span(_._1 != TrueLiteral)
         CaseWhen( h :+ t.head, None)
 
-      case e @ CaseWhen(branches, Some(elseValue))
-          if branches.forall(_._2.semanticEquals(elseValue)) =>
+      case e @ CaseWhen(branches, elseOpt)
+          if branches.forall(_._2.semanticEquals(elseOpt.getOrElse(Literal(null, e.dataType)))) =>
+        val elseValue = elseOpt.getOrElse(Literal(null, e.dataType))
         // For non-deterministic conditions with side effect, we can not remove it, or change
         // the ordering. As a result, we try to remove the deterministic conditions from the tail.
         var hitNonDeterministicCond = false
@@ -510,6 +524,88 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
   }
 }
 
 
+/**
+ * Push the foldable expression into (if / case) branches.
+ */
+object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
+
+  // To be conservative here: it's only a guaranteed win if all but at most only one branch
+  // end up being not foldable.
+  private def atMostOneUnfoldable(exprs: Seq[Expression]): Boolean = {
+    val (foldables, others) = exprs.partition(_.foldable)
+    foldables.nonEmpty && others.length < 2
+  }
+
+  // Not all UnaryExpression can be pushed into (if / case) branches, e.g. Alias.
+  private def supportedUnaryExpression(e: UnaryExpression): Boolean = e match {
+    case _: IsNull | _: IsNotNull => true
+    case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
+    case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
+      true
+    case _: CastBase => true
+    case _: LastDay => true
+    case _: ExtractIntervalPart => true
+    case _: ArraySetLike => true
+    case _: ExtractValue => true
+    case _ => false
+  }
+
+  // Not all BinaryExpression can be pushed into (if / case) branches.
+  private def supportedBinaryExpression(e: BinaryExpression): Boolean = e match {
+    case _: BinaryComparison | _: StringPredicate | _: StringRegexExpression => true
+    case _: BinaryArithmetic => true
+    case _: BinaryMathExpression => true
+    case _: AddMonths | _: DateAdd | _: DateAddInterval | _: DateDiff | _: DateSub => true
+    case _: FindInSet | _: RoundBase => true
+    case _ => false
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case q: LogicalPlan => q transformExpressionsUp {
+      case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
+          if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
+        i.copy(
+          trueValue = u.withNewChildren(Array(trueValue)),
+          falseValue = u.withNewChildren(Array(falseValue)))
+
+      case u @ UnaryExpression(c @ CaseWhen(branches, elseValue))
+          if supportedUnaryExpression(u) && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
+        c.copy(
+          branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))),
+          elseValue.map(e => u.withNewChildren(Array(e))))
+
+      case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
+          if supportedBinaryExpression(b) && right.foldable &&
+            atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
+        i.copy(
+          trueValue = b.withNewChildren(Array(trueValue, right)),
+          falseValue = b.withNewChildren(Array(falseValue, right)))
+
+      case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue))
+          if supportedBinaryExpression(b) && left.foldable &&
+            atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
+        i.copy(
+          trueValue = b.withNewChildren(Array(left, trueValue)),
+          falseValue = b.withNewChildren(Array(left, falseValue)))
+
+      case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right)
+          if supportedBinaryExpression(b) && right.foldable &&
+            atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
+        c.copy(
+          branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))),
+          elseValue.map(e => b.withNewChildren(Array(e, right))))
+
+      case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
+          if supportedBinaryExpression(b) && left.foldable &&
+            atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
+        c.copy(
+          branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))),
+          elseValue.map(e => b.withNewChildren(Array(left, e))))
+    }
+  }
+}
+
+
 /**
  * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition.
  * For example, when the expression is just checking to see if a string starts with a given
```
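
Here is a comparable sketch for `PushFoldableIntoBranches`, again in the style of the optimizer test suites (hypothetical demo with made-up column names, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, EqualTo, GreaterThan, Literal}
import org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object PushFoldableIntoBranchesDemo {
  def main(args: Array[String]): Unit = {
    val relation = LocalRelation('id.long)

    // (CASE WHEN id = 1 THEN '1' WHEN id = 3 THEN '2' END) > '3': the foldable right-hand side
    // of a supported BinaryComparison is pushed into every branch, where it can later be
    // constant-folded away.
    val caseWhen = CaseWhen(
      Seq(
        (EqualTo('id, Literal(1L)), Literal("1")),
        (EqualTo('id, Literal(3L)), Literal("2"))),
      None)
    val query = relation.where(GreaterThan(caseWhen, Literal("3"))).analyze

    println(PushFoldableIntoBranches(query).treeString)
  }
}
```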

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala

Lines changed: 4 additions & 4 deletions
```diff
@@ -218,14 +218,14 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with
 
   test("Complementation Laws - null handling") {
     checkCondition('e && !'e,
-      testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), false)).analyze)
+      testRelationWithData.where(And(Literal(null, BooleanType), 'e.isNull)).analyze)
     checkCondition(!'e && 'e,
-      testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), false)).analyze)
+      testRelationWithData.where(And(Literal(null, BooleanType), 'e.isNull)).analyze)
 
     checkCondition('e || !'e,
-      testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), true)).analyze)
+      testRelationWithData.where(Or('e.isNotNull, Literal(null, BooleanType))).analyze)
     checkCondition(!'e || 'e,
-      testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), true)).analyze)
+      testRelationWithData.where(Or('e.isNotNull, Literal(null, BooleanType))).analyze)
   }
 
   test("Complementation Laws - negative case") {
```
