diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7683ee7074e7..90ec699877de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -418,6 +418,7 @@ object JavaTypeInference { inputObject, ObjectType(keyType.getRawType), serializerFor(_, keyType), + keyNullable = true, ObjectType(valueType.getRawType), serializerFor(_, valueType), valueNullable = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d580cf4d3391..f3c1e4150017 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -494,6 +494,7 @@ object ScalaReflection extends ScalaReflection { inputObject, dataTypeFor(keyType), serializerFor(_, keyType, keyPath, seenTypeSet), + keyNullable = !keyType.typeSymbol.asClass.isPrimitive, dataTypeFor(valueType), serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4b651836ff4d..d6d06aecc077 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -841,18 +841,21 @@ object ExternalMapToCatalyst { inputMap: Expression, keyType: DataType, keyConverter: Expression => Expression, + keyNullable: Boolean, valueType: DataType, valueConverter: Expression => Expression, valueNullable: Boolean): ExternalMapToCatalyst = { val id = curId.getAndIncrement() val keyName = "ExternalMapToCatalyst_key" + id + val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id val valueName = "ExternalMapToCatalyst_value" + id val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id ExternalMapToCatalyst( keyName, + keyIsNull, keyType, - keyConverter(LambdaVariable(keyName, "false", keyType, false)), + keyConverter(LambdaVariable(keyName, keyIsNull, keyType, keyNullable)), valueName, valueIsNull, valueType, @@ -868,6 +871,8 @@ object ExternalMapToCatalyst { * * @param key the name of the map key variable that used when iterate the map, and used as input for * the `keyConverter` + * @param keyIsNull the nullability of the map key variable that used when iterate the map, and + * used as input for the `keyConverter` * @param keyType the data type of the map key variable that used when iterate the map, and used as * input for the `keyConverter` * @param keyConverter A function that take the `key` as input, and converts it to catalyst format. @@ -883,6 +888,7 @@ object ExternalMapToCatalyst { */ case class ExternalMapToCatalyst private( key: String, + keyIsNull: String, keyType: DataType, keyConverter: Expression, value: String, @@ -913,6 +919,7 @@ case class ExternalMapToCatalyst private( val keyElementJavaType = ctx.javaType(keyType) val valueElementJavaType = ctx.javaType(valueType) + ctx.addMutableState("boolean", keyIsNull, "") ctx.addMutableState(keyElementJavaType, key, "") ctx.addMutableState("boolean", valueIsNull, "") ctx.addMutableState(valueElementJavaType, value, "") @@ -950,6 +957,12 @@ case class ExternalMapToCatalyst private( defineEntries -> defineKeyValue } + val keyNullCheck = if (ctx.isPrimitiveType(keyType)) { + s"$keyIsNull = false;" + } else { + s"$keyIsNull = $key == null;" + } + val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { s"$valueIsNull = false;" } else { @@ -972,6 +985,7 @@ case class ExternalMapToCatalyst private( $defineEntries while($entries.hasNext()) { $defineKeyValue + $keyNullCheck $valueNullCheck ${genKeyConverter.code} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 080f11b76938..bb1955a1ae24 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -355,12 +355,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { checkNullable[String](true) } - test("null check for map key") { + test("null check for map key: String") { val encoder = ExpressionEncoder[Map[String, Int]]() val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2)))) assert(e.getMessage.contains("Cannot use null as map key")) } + test("null check for map key: Integer") { + val encoder = ExpressionEncoder[Map[Integer, String]]() + val e = intercept[RuntimeException](encoder.toRow(Map((1, "a"), (null, "b")))) + assert(e.getMessage.contains("Cannot use null as map key")) + } + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = {