Skip to content

Commit 9686eeb

Browse files
committed
make sure the exprerssion is nullable if it returns null
1 parent d80d0ce commit 9686eeb

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -143,21 +143,26 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
143143
}
144144

145145
test("GetArrayStructFields") {
146-
val typeAS = ArrayType(StructType(StructField("a", IntegerType, false) :: Nil))
147-
val typeNullAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
148-
val arrayStruct = Literal.create(Seq(create_row(1)), typeAS)
149-
val nullArrayStruct = Literal.create(null, typeNullAS)
150-
151-
def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = {
152-
expr.dataType match {
153-
case ArrayType(StructType(fields), containsNull) =>
154-
val field = fields.find(_.name == fieldName).get
155-
GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull)
156-
}
146+
// test 4 types: struct field nullability X array element nullability
147+
val type1 = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
148+
val type2 = ArrayType(StructType(StructField("a", IntegerType, nullable = false) :: Nil))
149+
val type3 = ArrayType(StructType(StructField("a", IntegerType) :: Nil), containsNull = false)
150+
val type4 = ArrayType(
151+
StructType(StructField("a", IntegerType, nullable = false) :: Nil), containsNull = false)
152+
153+
val input1 = Literal.create(Seq(create_row(1)), type4)
154+
val input2 = Literal.create(Seq(create_row(null)), type3)
155+
val input3 = Literal.create(Seq(null), type2)
156+
val input4 = Literal.create(null, type1)
157+
158+
def getArrayStructFields(expr: Expression, fieldName: String): Expression = {
159+
ExtractValue.apply(expr, Literal.create(fieldName, StringType), _ == _)
157160
}
158161

159-
checkEvaluation(getArrayStructFields(arrayStruct, "a"), Seq(1))
160-
checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
162+
checkEvaluation(getArrayStructFields(input1, "a"), Seq(1))
163+
checkEvaluation(getArrayStructFields(input2, "a"), Seq(null))
164+
checkEvaluation(getArrayStructFields(input3, "a"), Seq(null))
165+
checkEvaluation(getArrayStructFields(input4, "a"), null)
161166
}
162167

163168
test("SPARK-32167: nullability of GetArrayStructFields") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ import org.apache.spark.util.Utils
4040

4141
/**
4242
* A few helper functions for expression evaluation testing. Mixin this trait to use them.
43+
*
44+
* Note: when you write unit test for an expression and call `checkEvaluation` to check the result,
45+
* please make sure that you explore all the cases that can lead to null result (including
46+
* null in struct fields, array elements and map values). The framework will test the
47+
* nullability flag of the expression automatically.
4348
*/
4449
trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestBase {
4550
self: SparkFunSuite =>

0 commit comments

Comments
 (0)