Skip to content

Commit 342dd56

Browse files
ueshinJackey Lee
authored andcommitted
[SPARK-26211][SQL][TEST][FOLLOW-UP] Combine test cases for In and InSet.
## What changes were proposed in this pull request? This is a follow pr of apache#23176. `In` and `InSet` are semantically equal, so the tests for `In` should pass with `InSet`, and vice versa. This combines those test cases. ## How was this patch tested? The combined tests and existing tests. Closes apache#23187 from ueshin/issues/SPARK-26211/in_inset_tests. Authored-by: Takuya UESHIN <ueshin@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 14a94fe commit 342dd56

File tree

1 file changed

+66
-94
lines changed

1 file changed

+66
-94
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 66 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -124,34 +124,43 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
124124
(null, false, null) ::
125125
(null, null, null) :: Nil)
126126

127-
test("basic IN predicate test") {
128-
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
127+
private def checkInAndInSet(in: In, expected: Any): Unit = {
128+
// expecting all in.list are Literal or NonFoldableLiteral.
129+
checkEvaluation(in, expected)
130+
checkEvaluation(InSet(in.value, HashSet() ++ in.list.map(_.eval())), expected)
131+
}
132+
133+
test("basic IN/INSET predicate test") {
134+
checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
129135
Literal(2))), null)
130-
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
136+
checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType),
131137
Seq(NonFoldableLiteral.create(null, IntegerType))), null)
132-
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
133-
checkEvaluation(In(Literal(1), Seq.empty), false)
134-
checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
135-
checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
138+
checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
139+
checkInAndInSet(In(Literal(1), Seq.empty), false)
140+
checkInAndInSet(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
141+
checkInAndInSet(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
136142
true)
137-
checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
143+
checkInAndInSet(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
138144
null)
139-
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
140-
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
141-
checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
145+
checkInAndInSet(In(Literal(1), Seq(Literal(1), Literal(2))), true)
146+
checkInAndInSet(In(Literal(2), Seq(Literal(1), Literal(2))), true)
147+
checkInAndInSet(In(Literal(3), Seq(Literal(1), Literal(2))), false)
148+
142149
checkEvaluation(
143150
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1),
144151
Literal(2)))),
145152
true)
153+
checkEvaluation(
154+
And(InSet(Literal(1), HashSet(1, 2)), InSet(Literal(2), Set(1, 2))),
155+
true)
146156

147157
val ns = NonFoldableLiteral.create(null, StringType)
148-
checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
149-
checkEvaluation(In(ns, Seq(ns)), null)
150-
checkEvaluation(In(Literal("a"), Seq(ns)), null)
151-
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true)
152-
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
153-
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
154-
158+
checkInAndInSet(In(ns, Seq(Literal("1"), Literal("2"))), null)
159+
checkInAndInSet(In(ns, Seq(ns)), null)
160+
checkInAndInSet(In(Literal("a"), Seq(ns)), null)
161+
checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true)
162+
checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
163+
checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
155164
}
156165

157166
test("IN with different types") {
@@ -187,11 +196,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
187196
} else {
188197
false
189198
}
190-
checkEvaluation(In(input(0), input.slice(1, 10)), expected)
199+
checkInAndInSet(In(input(0), input.slice(1, 10)), expected)
191200
}
192201

193202
val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t =>
194-
RandomDataGenerator.forType(t).isDefined && !t.isInstanceOf[DecimalType]
203+
RandomDataGenerator.forType(t).isDefined &&
204+
!t.isInstanceOf[DecimalType] && !t.isInstanceOf[BinaryType]
195205
} ++ Seq(DecimalType.USER_DEFAULT)
196206

197207
val atomicArrayTypes = atomicTypes.map(ArrayType(_, containsNull = true))
@@ -252,93 +262,55 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
252262
assert(ctx.inlinedMutableStates.isEmpty)
253263
}
254264

255-
test("INSET") {
256-
val hS = HashSet[Any]() + 1 + 2
257-
val nS = HashSet[Any]() + 1 + 2 + null
258-
val one = Literal(1)
259-
val two = Literal(2)
260-
val three = Literal(3)
261-
val nl = Literal(null)
262-
checkEvaluation(InSet(one, hS), true)
263-
checkEvaluation(InSet(two, hS), true)
264-
checkEvaluation(InSet(two, nS), true)
265-
checkEvaluation(InSet(three, hS), false)
266-
checkEvaluation(InSet(three, nS), null)
267-
checkEvaluation(InSet(nl, hS), null)
268-
checkEvaluation(InSet(nl, nS), null)
269-
270-
val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
271-
LongType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
272-
primitiveTypes.foreach { t =>
273-
val dataGen = RandomDataGenerator.forType(t, nullable = true).get
274-
val inputData = Seq.fill(10) {
275-
val value = dataGen.apply()
276-
value match {
277-
case d: Double if d.isNaN => 0.0d
278-
case f: Float if f.isNaN => 0.0f
279-
case _ => value
280-
}
281-
}
282-
val input = inputData.map(Literal(_))
283-
val expected = if (inputData(0) == null) {
284-
null
285-
} else if (inputData.slice(1, 10).contains(inputData(0))) {
286-
true
287-
} else if (inputData.slice(1, 10).contains(null)) {
288-
null
289-
} else {
290-
false
291-
}
292-
checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected)
293-
}
294-
}
295-
296-
test("INSET: binary") {
297-
val hS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte)
298-
val nS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + null
265+
test("IN/INSET: binary") {
299266
val onetwo = Literal(Array(1.toByte, 2.toByte))
300267
val three = Literal(Array(3.toByte))
301268
val threefour = Literal(Array(3.toByte, 4.toByte))
302-
val nl = Literal(null, onetwo.dataType)
303-
checkEvaluation(InSet(onetwo, hS), true)
304-
checkEvaluation(InSet(three, hS), true)
305-
checkEvaluation(InSet(three, nS), true)
306-
checkEvaluation(InSet(threefour, hS), false)
307-
checkEvaluation(InSet(threefour, nS), null)
308-
checkEvaluation(InSet(nl, hS), null)
309-
checkEvaluation(InSet(nl, nS), null)
269+
val nl = NonFoldableLiteral.create(null, onetwo.dataType)
270+
val hS = Seq(Literal(Array(1.toByte, 2.toByte)), Literal(Array(3.toByte)))
271+
val nS = Seq(Literal(Array(1.toByte, 2.toByte)), Literal(Array(3.toByte)),
272+
NonFoldableLiteral.create(null, onetwo.dataType))
273+
checkInAndInSet(In(onetwo, hS), true)
274+
checkInAndInSet(In(three, hS), true)
275+
checkInAndInSet(In(three, nS), true)
276+
checkInAndInSet(In(threefour, hS), false)
277+
checkInAndInSet(In(threefour, nS), null)
278+
checkInAndInSet(In(nl, hS), null)
279+
checkInAndInSet(In(nl, nS), null)
310280
}
311281

312-
test("INSET: struct") {
313-
val hS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value
314-
val nS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + null
282+
test("IN/INSET: struct") {
315283
val oneA = Literal.create((1, "a"))
316284
val twoB = Literal.create((2, "b"))
317285
val twoC = Literal.create((2, "c"))
318-
val nl = Literal(null, oneA.dataType)
319-
checkEvaluation(InSet(oneA, hS), true)
320-
checkEvaluation(InSet(twoB, hS), true)
321-
checkEvaluation(InSet(twoB, nS), true)
322-
checkEvaluation(InSet(twoC, hS), false)
323-
checkEvaluation(InSet(twoC, nS), null)
324-
checkEvaluation(InSet(nl, hS), null)
325-
checkEvaluation(InSet(nl, nS), null)
286+
val nl = NonFoldableLiteral.create(null, oneA.dataType)
287+
val hS = Seq(Literal.create((1, "a")), Literal.create((2, "b")))
288+
val nS = Seq(Literal.create((1, "a")), Literal.create((2, "b")),
289+
NonFoldableLiteral.create(null, oneA.dataType))
290+
checkInAndInSet(In(oneA, hS), true)
291+
checkInAndInSet(In(twoB, hS), true)
292+
checkInAndInSet(In(twoB, nS), true)
293+
checkInAndInSet(In(twoC, hS), false)
294+
checkInAndInSet(In(twoC, nS), null)
295+
checkInAndInSet(In(nl, hS), null)
296+
checkInAndInSet(In(nl, nS), null)
326297
}
327298

328-
test("INSET: array") {
329-
val hS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value
330-
val nS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + null
299+
test("IN/INSET: array") {
331300
val onetwo = Literal.create(Seq(1, 2))
332301
val three = Literal.create(Seq(3))
333302
val threefour = Literal.create(Seq(3, 4))
334-
val nl = Literal(null, onetwo.dataType)
335-
checkEvaluation(InSet(onetwo, hS), true)
336-
checkEvaluation(InSet(three, hS), true)
337-
checkEvaluation(InSet(three, nS), true)
338-
checkEvaluation(InSet(threefour, hS), false)
339-
checkEvaluation(InSet(threefour, nS), null)
340-
checkEvaluation(InSet(nl, hS), null)
341-
checkEvaluation(InSet(nl, nS), null)
303+
val nl = NonFoldableLiteral.create(null, onetwo.dataType)
304+
val hS = Seq(Literal.create(Seq(1, 2)), Literal.create(Seq(3)))
305+
val nS = Seq(Literal.create(Seq(1, 2)), Literal.create(Seq(3)),
306+
NonFoldableLiteral.create(null, onetwo.dataType))
307+
checkInAndInSet(In(onetwo, hS), true)
308+
checkInAndInSet(In(three, hS), true)
309+
checkInAndInSet(In(three, nS), true)
310+
checkInAndInSet(In(threefour, hS), false)
311+
checkInAndInSet(In(threefour, nS), null)
312+
checkInAndInSet(In(nl, hS), null)
313+
checkInAndInSet(In(nl, nS), null)
342314
}
343315

344316
private case class MyStruct(a: Long, b: String)

0 commit comments

Comments
 (0)