@@ -46,7 +46,9 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
4646 private val unreachableBranch = (FalseLiteral , Literal (20 ))
4747 private val nullBranch = (Literal .create(null , NullType ), Literal (30 ))
4848
49- private val testRelation = LocalRelation (' a .int)
49+ val isNotNullCond = IsNotNull (UnresolvedAttribute (" a" ))
50+ val isNullCond = IsNull (UnresolvedAttribute (" a" ))
51+ val notCond = Not (UnresolvedAttribute (" c" ))
5052
5153 test(" simplify if" ) {
5254 assertEquivalent(
@@ -122,4 +124,54 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
122124 None ),
123125 CaseWhen (normalBranch :: trueBranch :: Nil , None ))
124126 }
127+
128+ test(" remove a branch in CaseWhen if a cond in this branch is previously seen" ) {
129+ assertEquivalent(
130+ CaseWhen ((GreaterThan (Rand (0 ), Literal (0.5 )), Literal (1 )) ::
131+ (GreaterThan (Rand (0 ), Literal (0.5 )), Literal (2 )) ::
132+ (NonFoldableLiteral (true ), Literal (3 )) ::
133+ (LessThan (Rand (1 ), Literal (0.5 )), Literal (4 )) ::
134+ (NonFoldableLiteral (true ), Literal (5 )) ::
135+ (NonFoldableLiteral (false ), Literal (6 )) ::
136+ (NonFoldableLiteral (false ), Literal (7 )) ::
137+ Nil ,
138+ None ),
139+ CaseWhen ((GreaterThan (Rand (0 ), Literal (0.5 )), Literal (1 )) ::
140+ (GreaterThan (Rand (0 ), Literal (0.5 )), Literal (2 )) ::
141+ (NonFoldableLiteral (true ), Literal (3 )) ::
142+ (LessThan (Rand (1 ), Literal (0.5 )), Literal (4 )) ::
143+ (NonFoldableLiteral (false ), Literal (6 )) ::
144+ Nil ,
145+ None )
146+ )
147+ }
148+
149+ test(" combine two adjacent branches in CaseWhen if they have the same output values" ) {
150+ assertEquivalent(
151+ CaseWhen ((GreaterThan (Rand (0 ), Literal (0.5 )), Literal (1 )) ::
152+ (NonFoldableLiteral (true ), Literal (1 )) ::
153+ (LessThan (Rand (1 ), Literal (0.5 )), Literal (3 )) ::
154+ (NonFoldableLiteral (true ), Literal (3 )) ::
155+ (NonFoldableLiteral (false ), Literal (4 )) ::
156+ Nil ,
157+ None ),
158+ CaseWhen ((Or (GreaterThan (Rand (0 ), Literal (0.5 )), NonFoldableLiteral (true )), Literal (1 )) ::
159+ (Or (LessThan (Rand (1 ), Literal (0.5 )), NonFoldableLiteral (true )), Literal (3 )) ::
160+ (NonFoldableLiteral (false ), Literal (4 )) ::
161+ Nil ,
162+ None )
163+ )
164+
165+ // The first two conditions can be combined, and then the optimizer uses rule in `Or`
166+ // to be optimized into `TrueLiteral`. Thus, the entire `CaseWhen` can be removed.
167+ assertEquivalent(
168+ CaseWhen ((UnresolvedAttribute (" a" ), Literal (1 )) ::
169+ (Not (UnresolvedAttribute (" a" )), Literal (1 )) ::
170+ (LessThan (Rand (1 ), Literal (0.5 )), Literal (3 )) ::
171+ (NonFoldableLiteral (true ), Literal (4 )) ::
172+ (NonFoldableLiteral (false ), Literal (5 )) ::
173+ Nil ,
174+ None ),
175+ Literal (1 ))
176+ }
125177}
0 commit comments