@@ -143,21 +143,26 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
143143 }
144144
145145 test(" GetArrayStructFields" ) {
146- val typeAS = ArrayType (StructType (StructField (" a" , IntegerType , false ) :: Nil ))
147- val typeNullAS = ArrayType (StructType (StructField (" a" , IntegerType ) :: Nil ))
148- val arrayStruct = Literal .create(Seq (create_row(1 )), typeAS)
149- val nullArrayStruct = Literal .create(null , typeNullAS)
150-
151- def getArrayStructFields (expr : Expression , fieldName : String ): GetArrayStructFields = {
152- expr.dataType match {
153- case ArrayType (StructType (fields), containsNull) =>
154- val field = fields.find(_.name == fieldName).get
155- GetArrayStructFields (expr, field, fields.indexOf(field), fields.length, containsNull)
156- }
146+ // test 4 types: struct field nullability X array element nullability
147+ val type1 = ArrayType (StructType (StructField (" a" , IntegerType ) :: Nil ))
148+ val type2 = ArrayType (StructType (StructField (" a" , IntegerType , nullable = false ) :: Nil ))
149+ val type3 = ArrayType (StructType (StructField (" a" , IntegerType ) :: Nil ), containsNull = false )
150+ val type4 = ArrayType (
151+ StructType (StructField (" a" , IntegerType , nullable = false ) :: Nil ), containsNull = false )
152+
153+ val input1 = Literal .create(Seq (create_row(1 )), type4)
154+ val input2 = Literal .create(Seq (create_row(null )), type3)
155+ val input3 = Literal .create(Seq (null ), type2)
156+ val input4 = Literal .create(null , type1)
157+
158+ def getArrayStructFields (expr : Expression , fieldName : String ): Expression = {
159+ ExtractValue .apply(expr, Literal .create(fieldName, StringType ), _ == _)
157160 }
158161
159- checkEvaluation(getArrayStructFields(arrayStruct, " a" ), Seq (1 ))
160- checkEvaluation(getArrayStructFields(nullArrayStruct, " a" ), null )
162+ checkEvaluation(getArrayStructFields(input1, " a" ), Seq (1 ))
163+ checkEvaluation(getArrayStructFields(input2, " a" ), Seq (null ))
164+ checkEvaluation(getArrayStructFields(input3, " a" ), Seq (null ))
165+ checkEvaluation(getArrayStructFields(input4, " a" ), null )
161166 }
162167
163168 test(" SPARK-32167: nullability of GetArrayStructFields" ) {
0 commit comments