Skip to content

Commit ee5e6dd

Browse files
committed
fix
1 parent 593678c commit ee5e6dd

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -511,9 +511,13 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
511511
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
512512
}
513513

514-
case EqualTo(CaseWhen(branches, _), right)
515-
if branches.count(_._2.semanticEquals(right)) == 1 =>
516-
branches.filter(_._2.semanticEquals(right)).head._1
514+
case EqualTo(CaseWhen(branches, elseValue), right)
515+
if right.foldable && branches.forall(_._2.foldable) =>
516+
(branches.filter(_._2.equals(right)).map(_._1) ++
517+
elseValue.map(e => EqualTo(e, right))).reduceLeftOption(Or) match {
518+
case Some(value) => value
519+
case None => FalseLiteral
520+
}
517521
}
518522
}
519523
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -202,23 +202,29 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
202202

203203
test("SPARK-33315: simplify CaseWhen with EqualTo") {
204204
val e1 = EqualTo(UnresolvedAttribute("a"), Literal(100))
205-
val e2 = GreaterThan(UnresolvedAttribute("b"), Literal(1000))
206-
val e3 = IsNotNull(UnresolvedAttribute("c"))
205+
val e3 = EqualTo(UnresolvedAttribute("c"), Literal(true))
207206
val caseWhen = CaseWhen(
208-
Seq(normalBranch, (e1, Literal(1)), (e2, Literal(2)), (e3, Literal(3))), None)
209-
assertEquivalent(EqualTo(caseWhen, Literal(1)), e1)
210-
assertEquivalent(EqualTo(caseWhen, Literal(3)), e3)
207+
Seq(normalBranch, (e1, Literal(1)), (e3, Literal(2))), Some(UnresolvedAttribute("b")))
211208

212-
assertEquivalent(
213-
And(EqualTo(caseWhen, Literal(1)), EqualTo(caseWhen, Literal(2))),
214-
And(e1, e2))
215-
assertEquivalent(
216-
Or(EqualTo(caseWhen, Literal(1)), EqualTo(caseWhen, Literal(2))),
217-
Or(e1, e2))
209+
assertEquivalent(EqualTo(caseWhen, Literal(1)),
210+
Or(e1, EqualTo(UnresolvedAttribute("b"), Literal(1))))
211+
assertEquivalent(EqualTo(caseWhen, Literal(3)),
212+
EqualTo(UnresolvedAttribute("b"), Literal(3)))
213+
assertEquivalent(EqualTo(caseWhen, Literal(4)),
214+
EqualTo(UnresolvedAttribute("b"), Literal(4)))
215+
216+
assertEquivalent(And(EqualTo(caseWhen, Literal(1)), EqualTo(caseWhen, Literal(2))),
217+
And(Or(e1, EqualTo(UnresolvedAttribute("b"), Literal(1))),
218+
Or(e3, EqualTo(UnresolvedAttribute("b"), Literal(2)))))
218219

219-
assertEquivalent(EqualTo(caseWhen, Literal(4)), EqualTo(caseWhen, Literal(4)))
220220
assertEquivalent(
221-
Or(EqualTo(caseWhen, Literal(3)), EqualTo(caseWhen, Literal(4))),
222-
Or(e3, EqualTo(caseWhen, Literal(4))))
221+
EqualTo(CaseWhen(Seq(normalBranch, (e1, Literal(1)), (e3, Literal(2))), None), Literal(3)),
222+
FalseLiteral)
223+
224+
// Do not simplify if it contains non foldable expressions.
225+
assertEquivalent(EqualTo(caseWhen, NonFoldableLiteral(true)),
226+
EqualTo(caseWhen, NonFoldableLiteral(true)))
227+
val nonFoldable = CaseWhen(Seq(normalBranch, (e1, UnresolvedAttribute("b"))), None)
228+
assertEquivalent(EqualTo(nonFoldable, Literal(1)), EqualTo(nonFoldable, Literal(1)))
223229
}
224230
}

0 commit comments

Comments
 (0)