@@ -202,23 +202,29 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
202202
203203 test(" SPARK-33315: simplify CaseWhen with EqualTo" ) {
204204 val e1 = EqualTo (UnresolvedAttribute (" a" ), Literal (100 ))
205- val e2 = GreaterThan (UnresolvedAttribute (" b" ), Literal (1000 ))
206- val e3 = IsNotNull (UnresolvedAttribute (" c" ))
205+ val e3 = EqualTo (UnresolvedAttribute (" c" ), Literal (true ))
207206 val caseWhen = CaseWhen (
208- Seq (normalBranch, (e1, Literal (1 )), (e2, Literal (2 )), (e3, Literal (3 ))), None )
209- assertEquivalent(EqualTo (caseWhen, Literal (1 )), e1)
210- assertEquivalent(EqualTo (caseWhen, Literal (3 )), e3)
207+ Seq (normalBranch, (e1, Literal (1 )), (e3, Literal (2 ))), Some (UnresolvedAttribute (" b" )))
211208
212- assertEquivalent(
213- And (EqualTo (caseWhen, Literal (1 )), EqualTo (caseWhen, Literal (2 ))),
214- And (e1, e2))
215- assertEquivalent(
216- Or (EqualTo (caseWhen, Literal (1 )), EqualTo (caseWhen, Literal (2 ))),
217- Or (e1, e2))
209+ assertEquivalent(EqualTo (caseWhen, Literal (1 )),
210+ Or (e1, EqualTo (UnresolvedAttribute (" b" ), Literal (1 ))))
211+ assertEquivalent(EqualTo (caseWhen, Literal (3 )),
212+ EqualTo (UnresolvedAttribute (" b" ), Literal (3 )))
213+ assertEquivalent(EqualTo (caseWhen, Literal (4 )),
214+ EqualTo (UnresolvedAttribute (" b" ), Literal (4 )))
215+
216+ assertEquivalent(And (EqualTo (caseWhen, Literal (1 )), EqualTo (caseWhen, Literal (2 ))),
217+ And (Or (e1, EqualTo (UnresolvedAttribute (" b" ), Literal (1 ))),
218+ Or (e3, EqualTo (UnresolvedAttribute (" b" ), Literal (2 )))))
218219
219- assertEquivalent(EqualTo (caseWhen, Literal (4 )), EqualTo (caseWhen, Literal (4 )))
220220 assertEquivalent(
221- Or (EqualTo (caseWhen, Literal (3 )), EqualTo (caseWhen, Literal (4 ))),
222- Or (e3, EqualTo (caseWhen, Literal (4 ))))
221+ EqualTo (CaseWhen (Seq (normalBranch, (e1, Literal (1 )), (e3, Literal (2 ))), None ), Literal (3 )),
222+ FalseLiteral )
223+
224+ // Do not simplify if it contains non foldable expressions.
225+ assertEquivalent(EqualTo (caseWhen, NonFoldableLiteral (true )),
226+ EqualTo (caseWhen, NonFoldableLiteral (true )))
227+ val nonFoldable = CaseWhen (Seq (normalBranch, (e1, UnresolvedAttribute (" b" ))), None )
228+ assertEquivalent(EqualTo (nonFoldable, Literal (1 )), EqualTo (nonFoldable, Literal (1 )))
223229 }
224230}
0 commit comments