From 8c6039c7b7f31f0343c4b0098a4e12dfff125128 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Mon, 7 May 2018 14:23:18 +0200 Subject: [PATCH 1/5] [SPARK-23934][SQL] Adding map_from_entries function --- python/pyspark/sql/functions.py | 20 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 227 +++++++++++++++++- .../CollectionExpressionsSuite.scala | 58 +++++ .../org/apache/spark/sql/functions.scala | 7 + .../spark/sql/DataFrameFunctionsSuite.scala | 52 ++++ 6 files changed, 364 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ac3c79766702..fa804ccc0844 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2304,6 +2304,26 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) +@since(2.4) +def map_from_entries(col): + """ + Collection function: Returns a map created from the given array of entries. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_from_entries + >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data") + >>> df.select(map_from_entries("data").alias("map")).show() + +----------------+ + | map| + +----------------+ + |[1 -> a, 2 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_entries(_to_java_column(col))) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 87b0911e150c..e41e889762dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), + expression[MapFromEntries]("map_from_entries"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 12b9ab2b272a..9334953d64d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -22,10 +22,12 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.collection.OpenHashSet /** * Given an array or map, returns its size. Returns -1 if null. 
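For reference, a minimal usage sketch of the function this patch series adds, mirroring the Python docstring above. It assumes a local SparkSession named spark and is illustrative only, not part of the patch:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
    // An array<struct<col1:int, col2:string>> literal collapses into map<int, string>.
    spark.sql("SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) AS m").show()
    // +----------------+
    // |               m|
    // +----------------+
    // |[1 -> a, 2 -> b]|
    // +----------------+
    spark.stop()
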
@@ -118,6 +120,229 @@ case class MapValues(child: Expression) override def prettyName: String = "map_values" } +/** + * Returns a map created from the given array of entries. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.", + examples = """ + Examples: + > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); + {1:"a",2:"b"} + """, + since = "2.4.0") +case class MapFromEntries(child: Expression) extends UnaryExpression +{ + private lazy val resolvedDataType: Option[MapType] = child.dataType match { + case ArrayType( + StructType(Array( + StructField(_, keyType, false, _), + StructField(_, valueType, valueNullable, _))), + false) => Some(MapType(keyType, valueType, valueNullable)) + case _ => None + } + + override def dataType: MapType = resolvedDataType.get + + override def checkInputDataTypes(): TypeCheckResult = resolvedDataType match { + case Some(_) => TypeCheckResult.TypeCheckSuccess + case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + + s"${child.dataType.simpleString} type. $prettyName accepts only null-free arrays " + + "of pair structs. Values of the first struct field can't contain nulls and produce " + + "duplicates.") + } + + override protected def nullSafeEval(input: Any): Any = { + val arrayData = input.asInstanceOf[ArrayData] + val length = arrayData.numElements() + val keyArray = new Array[AnyRef](length) + val keySet = new OpenHashSet[AnyRef]() + val valueArray = new Array[AnyRef](length) + var i = 0; + while (i < length) { + val entry = arrayData.getStruct(i, 2) + val key = entry.get(0, dataType.keyType) + if (key == null) { + throw new RuntimeException("The first field from a struct (key) can't be null.") + } + if (keySet.contains(key)) { + throw new RuntimeException("The first field from a struct (key) can't produce duplicates.") + } + keySet.add(key) + keyArray.update(i, key) + val value = entry.get(1, dataType.valueType) + valueArray.update(i, value) + i += 1 + } + ArrayBasedMapData(keyArray, valueArray) + } + + private def getHashSetDetails(): (String, String) = dataType.keyType match { + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", "Long") + case _ => ("", "Object") + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numElements = ctx.freshName("numElements") + val keySet = ctx.freshName("keySet") + val hsClass = classOf[OpenHashSet[_]].getName + val tagPrefix = "scala.reflect.ClassTag$.MODULE$." 
+ val (hsSuffix, tagSuffix) = getHashSetDetails() + val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, c, ev.value, keySet, numElements) + } else { + genCodeForAnyElements(ctx, c, ev.value, keySet, numElements) + } + s""" + |final int $numElements = $c.numElements(); + |final $hsClass$hsSuffix $keySet = new $hsClass$hsSuffix($tagPrefix$tagSuffix()); + |$code + """.stripMargin + }) + } + + private def genCodeForAssignmentLoop( + ctx: CodegenContext, + childVariable: String, + numElements: String, + keySet: String, + keyAssignment: (String, String) => String, + valueAssignment: (String, String) => String): String = { + val entry = ctx.freshName("entry") + val key = ctx.freshName("key") + val idx = ctx.freshName("idx") + val keyType = CodeGenerator.javaType(dataType.keyType) + + s""" + |for (int $idx = 0; $idx < $numElements; $idx++) { + | InternalRow $entry = $childVariable.getStruct($idx, 2); + | if ($entry.isNullAt(0)) { + | throw new RuntimeException("The first field from a struct (key) can't be null."); + | } + | $keyType $key = ${CodeGenerator.getValue(entry, dataType.keyType, "0")}; + | if ($keySet.contains($key)) { + | throw new RuntimeException( + | "The first field from a struct (key) can't produce duplicates."); + | } + | $keySet.add($key); + | ${keyAssignment(key, idx)} + | ${valueAssignment(entry, idx)} + |} + """.stripMargin + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + keySet: String, + numElements: String): String = { + val byteArraySize = ctx.freshName("byteArraySize") + val keySectionSize = ctx.freshName("keySectionSize") + val valueSectionSize = ctx.freshName("valueSectionSize") + val data = ctx.freshName("byteArray") + val unsafeMapData = ctx.freshName("unsafeMapData") + val keyArrayData = ctx.freshName("keyArrayData") + val valueArrayData = ctx.freshName("valueArrayData") + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val keySize = dataType.keyType.defaultSize + val valueSize = dataType.valueType.defaultSize + val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numElements, $keySize)" + val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numElements, $valueSize)" + val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) + + val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" + if (dataType.valueContainsNull) { + s""" + |if ($entry.isNullAt(1)) { + | $valueArrayData.setNullAt($idx); + |} else { + | $valueNullUnsafeAssignment + |} + """.stripMargin + } else { + valueNullUnsafeAssignment + } + } + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + numElements, + keySet, + keyAssignment, + valueAssignment + ) + + s""" + |final long $keySectionSize = $kByteSize; + |final long $valueSectionSize = $vByteSize; + |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; + |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | ${genCodeForAnyElements(ctx, childVariable, mapData, keySet, numElements)} + |} 
else { + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeMapData $unsafeMapData = new UnsafeMapData(); + | Platform.putLong($data, $baseOffset, $keySectionSize); + | Platform.putLong($data, ${baseOffset + 8}, $numElements); + | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numElements); + | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); + | ArrayData $keyArrayData = $unsafeMapData.keyArray(); + | ArrayData $valueArrayData = $unsafeMapData.valueArray(); + | $assignmentLoop + | $mapData = $unsafeMapData; + |} + """.stripMargin + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + keySet: String, + numElements: String): String = { + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val mapDataClass = classOf[ArrayBasedMapData].getName() + + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + if (dataType.valueContainsNull && isValuePrimitive) { + s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;" + } else { + s"$values[$idx] = $value;" + } + } + val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;" + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + numElements, + keySet, + keyAssignment, + valueAssignment) + + s""" + |final Object[] $keys = new Object[$numElements]; + |final Object[] $values = new Object[$numElements]; + |$assignmentLoop + |$mapData = $mapDataClass.apply($keys, $values); + """.stripMargin + } + + override def prettyName: String = "map_from_entries" +} + + /** * Common base class for [[SortArray]] and [[ArraySort]]. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a2851d071c7c..ae9f31231f5a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -56,6 +57,63 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapValues(m2), null) } + test("MapFromEntries") { + def arrayType(keyType: DataType, valueType: DataType) : DataType = { + ArrayType(StructType(Seq( + StructField("a", keyType, false), + StructField("b", valueType))), + false) + } + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys and values + val aiType = arrayType(IntegerType, IntegerType) + val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) + val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai2 = Literal.create(Seq.empty, aiType) + val ai3 = Literal.create(null, aiType) + val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) + val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) + val aby = Literal.create(Seq(r(1.toByte, 10.toByte)), arrayType(ByteType, ByteType)) + val ash = Literal.create(Seq(r(1.toShort, 10.toShort)), arrayType(ShortType, ShortType)) + val alo = Literal.create(Seq(r(1L, 10L)), arrayType(LongType, LongType)) + + checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai2), Map.empty) + checkEvaluation(MapFromEntries(ai3), null) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(ai4), + "The first field from a struct (key) can't produce duplicates.") + checkExceptionInExpression[RuntimeException]( + MapFromEntries(ai5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(aby), Map(1.toByte -> 10.toByte)) + checkEvaluation(MapFromEntries(ash), Map(1.toShort -> 10.toShort)) + checkEvaluation(MapFromEntries(alo), Map(1L -> 10L)) + + // Non-primitive-type keys and values + val asType = arrayType(StringType, StringType) + val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType) + val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as2 = Literal.create(Seq.empty, asType) + val as3 = Literal.create(null, asType) + val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) + val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) + + checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as2), Map.empty) + checkEvaluation(MapFromEntries(as3), null) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(as4), + "The first field from a struct (key) can't produce duplicates.") + checkExceptionInExpression[RuntimeException]( + MapFromEntries(as5), + "The first field from a struct (key) can't 
be null.") + + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8f9e4ae18b3f..222d8f8d25a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3414,6 +3414,13 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + /** + * Returns a map created from the given array of entries. + * @group collection_funcs + * @since 2.4.0 + */ + def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ecce06f4c075..196194c7340d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -405,6 +405,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("map_from_entries function") { + def dummy_filter(c: Column): Column = c.isNull || c.isNotNull + + def arrayType(keyType: DataType, valueType: DataType) : StructType = { + StructType(Seq(StructField( + "a", + ArrayType( + StructType(Seq( + StructField("k", keyType, false), + StructField("v", valueType, true))), + false)))) + } + + // Test cases with primitive-type keys and values + val irdd = spark.sparkContext.parallelize(Seq( + Row(Seq(Row(1, 10), Row(2, 20), Row(3, 10))), + Row(Seq(Row(1, null), Row(2, null))), + Row(Seq.empty), + Row(null))) + val idf = spark.createDataFrame(irdd, arrayType(IntegerType, IntegerType)) + val iExpected = Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 10)), + Row(Map(1 -> null, 2 -> null)), + Row(Map.empty), + Row(null)) + + checkAnswer(idf.select(map_from_entries('a)), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) + checkAnswer(idf.filter(dummy_filter('a)).select(map_from_entries('a)), iExpected) + + // Test cases with non-primitive-type keys and values + val srdd = spark.sparkContext.parallelize(Seq( + Row(Seq(Row("a", "aa"), Row("b", "bb"), Row("c", "aa"))), + Row(Seq(Row("a", null), Row("b", null))), + Row(Seq.empty), + Row(null))) + val sdf = spark.createDataFrame(srdd, arrayType(StringType, StringType)) + val sExpected = Seq( + Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")), + Row(Map("a" -> null, "b" -> null)), + Row(Map.empty), + Row(null)) + + checkAnswer(sdf.select(map_from_entries('a)), sExpected) + checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) + checkAnswer(sdf.filter(dummy_filter('a)).select(map_from_entries('a)), sExpected) + + // Error test cases + intercept[AnalysisException](idf.selectExpr("map_from_entries(array(struct(null, 1)))")) + intercept[AnalysisException](idf.selectExpr("map_from_entries(array(struct(1, 10), null))")) + } + test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x"), From 25aa87976ce22aad3ceb7ca66bce5a98d185746b Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Thu, 10 May 2018 15:06:29 +0200 Subject: [PATCH 2/5] [SPARK-23934][SQL] Addressing review comments --- .../expressions/collectionOperations.scala | 129 +++++++++++------- .../CollectionExpressionsSuite.scala | 
14 +- .../spark/sql/DataFrameFunctionsSuite.scala | 58 ++++---- 3 files changed, 119 insertions(+), 82 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9334953d64d6..d3d530b224c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -131,47 +131,54 @@ case class MapValues(child: Expression) {1:"a",2:"b"} """, since = "2.4.0") -case class MapFromEntries(child: Expression) extends UnaryExpression -{ - private lazy val resolvedDataType: Option[MapType] = child.dataType match { +case class MapFromEntries(child: Expression) extends UnaryExpression { + + @transient + private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { case ArrayType( StructType(Array( - StructField(_, keyType, false, _), + StructField(_, keyType, keyNullable, _), StructField(_, valueType, valueNullable, _))), - false) => Some(MapType(keyType, valueType, valueNullable)) + containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) case _ => None } - override def dataType: MapType = resolvedDataType.get + private def nullEntries: Boolean = dataTypeDetails.get._3 + + override def dataType: MapType = dataTypeDetails.get._1 - override def checkInputDataTypes(): TypeCheckResult = resolvedDataType match { + override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { case Some(_) => TypeCheckResult.TypeCheckSuccess case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + - s"${child.dataType.simpleString} type. $prettyName accepts only null-free arrays " + - "of pair structs. Values of the first struct field can't contain nulls and produce " + - "duplicates.") + s"${child.dataType.simpleString} type. 
$prettyName accepts only arrays of pair structs.") } override protected def nullSafeEval(input: Any): Any = { val arrayData = input.asInstanceOf[ArrayData] val length = arrayData.numElements() - val keyArray = new Array[AnyRef](length) + val numEntries = if (nullEntries) (0 until length).count(!arrayData.isNullAt(_)) else length + val keyArray = new Array[AnyRef](numEntries) val keySet = new OpenHashSet[AnyRef]() - val valueArray = new Array[AnyRef](length) + val valueArray = new Array[AnyRef](numEntries) var i = 0; + var j = 0; while (i < length) { - val entry = arrayData.getStruct(i, 2) - val key = entry.get(0, dataType.keyType) - if (key == null) { - throw new RuntimeException("The first field from a struct (key) can't be null.") - } - if (keySet.contains(key)) { - throw new RuntimeException("The first field from a struct (key) can't produce duplicates.") + if (!arrayData.isNullAt(i)) { + val entry = arrayData.getStruct(i, 2) + val key = entry.get(0, dataType.keyType) + if (key == null) { + throw new RuntimeException("The first field from a struct (key) can't be null.") + } + if (keySet.contains(key)) { + throw new RuntimeException( + "The first field from a struct (key) can't produce duplicates.") + } + keySet.add(key) + keyArray.update(j, key) + val value = entry.get(1, dataType.valueType) + valueArray.update(j, value) + j += 1 } - keySet.add(key) - keyArray.update(i, key) - val value = entry.get(1, dataType.valueType) - valueArray.update(i, value) i += 1 } ArrayBasedMapData(keyArray, valueArray) @@ -185,7 +192,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { - val numElements = ctx.freshName("numElements") + val length = ctx.freshName("length") + val numEntries = ctx.freshName("numEntries") val keySet = ctx.freshName("keySet") val hsClass = classOf[OpenHashSet[_]].getName val tagPrefix = "scala.reflect.ClassTag$.MODULE$." 
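This revision still rejects duplicate keys through an OpenHashSet in both the interpreted and generated paths. For orientation, a small plain-Scala sketch of that check (illustrative only, not part of the patch; the generated code additionally selects the $mcI$sp/$mcJ$sp primitive specializations for Int/Long keys through getHashSetDetails()):

    import org.apache.spark.util.collection.OpenHashSet

    val seenKeys = new OpenHashSet[AnyRef]()
    def registerKey(key: AnyRef): Unit = {
      if (key == null) {
        throw new RuntimeException("The first field from a struct (key) can't be null.")
      }
      if (seenKeys.contains(key)) {
        throw new RuntimeException("The first field from a struct (key) can't produce duplicates.")
      }
      seenKeys.add(key)
    }
    registerKey(Integer.valueOf(1))
    registerKey(Integer.valueOf(2))
    // registerKey(Integer.valueOf(1)) would now throw the duplicate-key error.
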
@@ -193,12 +201,25 @@ case class MapFromEntries(child: Expression) extends UnaryExpression val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) val code = if (isKeyPrimitive && isValuePrimitive) { - genCodeForPrimitiveElements(ctx, c, ev.value, keySet, numElements) + genCodeForPrimitiveElements(ctx, c, ev.value, keySet, length, numEntries) + } else { + genCodeForAnyElements(ctx, c, ev.value, keySet, length, numEntries) + } + val numEntriesAssignment = if (nullEntries) { + val idx = ctx.freshName("idx") + s""" + |int $numEntries = 0; + |for (int $idx = 0; $idx < $length; $idx++) { + | if (!$c.isNullAt($idx)) $numEntries++; + |} + """.stripMargin } else { - genCodeForAnyElements(ctx, c, ev.value, keySet, numElements) + s"final int $numEntries = $length;" } + s""" - |final int $numElements = $c.numElements(); + |final int $length = $c.numElements(); + |$numEntriesAssignment |final $hsClass$hsSuffix $keySet = new $hsClass$hsSuffix($tagPrefix$tagSuffix()); |$code """.stripMargin @@ -208,29 +229,41 @@ case class MapFromEntries(child: Expression) extends UnaryExpression private def genCodeForAssignmentLoop( ctx: CodegenContext, childVariable: String, - numElements: String, + length: String, keySet: String, keyAssignment: (String, String) => String, valueAssignment: (String, String) => String): String = { val entry = ctx.freshName("entry") val key = ctx.freshName("key") - val idx = ctx.freshName("idx") + val i = ctx.freshName("idx") + val j = ctx.freshName("idx") + val keyType = CodeGenerator.javaType(dataType.keyType) + val nullEntryCheck = if (nullEntries) s"if ($childVariable.isNullAt($i)) continue;" else "" + val nullKeyCheck = if (dataTypeDetails.get._2) { + s""" + |if ($entry.isNullAt(0)) { + | throw new RuntimeException("The first field from a struct (key) can't be null."); + |} + """.stripMargin + } else { + "" + } s""" - |for (int $idx = 0; $idx < $numElements; $idx++) { - | InternalRow $entry = $childVariable.getStruct($idx, 2); - | if ($entry.isNullAt(0)) { - | throw new RuntimeException("The first field from a struct (key) can't be null."); - | } + |for (int $i = 0, $j = 0; $i < $length; $i++) { + | $nullEntryCheck + | InternalRow $entry = $childVariable.getStruct($i, 2); + | $nullKeyCheck | $keyType $key = ${CodeGenerator.getValue(entry, dataType.keyType, "0")}; | if ($keySet.contains($key)) { | throw new RuntimeException( | "The first field from a struct (key) can't produce duplicates."); | } | $keySet.add($key); - | ${keyAssignment(key, idx)} - | ${valueAssignment(entry, idx)} + | ${keyAssignment(key, j)} + | ${valueAssignment(entry, j)} + | $j++; |} """.stripMargin } @@ -240,7 +273,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression childVariable: String, mapData: String, keySet: String, - numElements: String): String = { + length: String, + numEntries: String): String = { val byteArraySize = ctx.freshName("byteArraySize") val keySectionSize = ctx.freshName("keySectionSize") val valueSectionSize = ctx.freshName("valueSectionSize") @@ -252,8 +286,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression val baseOffset = Platform.BYTE_ARRAY_OFFSET val keySize = dataType.keyType.defaultSize val valueSize = dataType.valueType.defaultSize - val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numElements, $keySize)" - val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numElements, $valueSize)" + val kByteSize = 
s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" + val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) @@ -276,7 +310,7 @@ case class MapFromEntries(child: Expression) extends UnaryExpression val assignmentLoop = genCodeForAssignmentLoop( ctx, childVariable, - numElements, + length, keySet, keyAssignment, valueAssignment @@ -287,13 +321,13 @@ case class MapFromEntries(child: Expression) extends UnaryExpression |final long $valueSectionSize = $vByteSize; |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | ${genCodeForAnyElements(ctx, childVariable, mapData, keySet, numElements)} + | ${genCodeForAnyElements(ctx, childVariable, mapData, keySet, length, numEntries)} |} else { | final byte[] $data = new byte[(int)$byteArraySize]; | UnsafeMapData $unsafeMapData = new UnsafeMapData(); | Platform.putLong($data, $baseOffset, $keySectionSize); - | Platform.putLong($data, ${baseOffset + 8}, $numElements); - | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numElements); + | Platform.putLong($data, ${baseOffset + 8}, $numEntries); + | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries); | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); | ArrayData $keyArrayData = $unsafeMapData.keyArray(); | ArrayData $valueArrayData = $unsafeMapData.valueArray(); @@ -308,7 +342,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression childVariable: String, mapData: String, keySet: String, - numElements: String): String = { + length: String, + numEntries: String): String = { val keys = ctx.freshName("keys") val values = ctx.freshName("values") val mapDataClass = classOf[ArrayBasedMapData].getName() @@ -326,14 +361,14 @@ case class MapFromEntries(child: Expression) extends UnaryExpression val assignmentLoop = genCodeForAssignmentLoop( ctx, childVariable, - numElements, + length, keySet, keyAssignment, valueAssignment) s""" - |final Object[] $keys = new Object[$numElements]; - |final Object[] $values = new Object[$numElements]; + |final Object[] $keys = new Object[$numEntries]; + |final Object[] $values = new Object[$numEntries]; |$assignmentLoop |$mapData = $mapDataClass.apply($keys, $values); """.stripMargin diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index ae9f31231f5a..b93acecddce1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -59,10 +59,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper test("MapFromEntries") { def arrayType(keyType: DataType, valueType: DataType) : DataType = { - ArrayType(StructType(Seq( - StructField("a", keyType, false), - StructField("b", valueType))), - false) + ArrayType( + StructType(Seq( + StructField("a", keyType), + StructField("b", valueType))), + true) } def r(values: Any*): InternalRow = create_row(values: _*) @@ -74,6 +75,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val ai3 = 
Literal.create(null, aiType) val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) val aby = Literal.create(Seq(r(1.toByte, 10.toByte)), arrayType(ByteType, ByteType)) val ash = Literal.create(Seq(r(1.toShort, 10.toShort)), arrayType(ShortType, ShortType)) val alo = Literal.create(Seq(r(1L, 10L)), arrayType(LongType, LongType)) @@ -88,6 +90,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkExceptionInExpression[RuntimeException]( MapFromEntries(ai5), "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(ai6), Map(2 -> 20)) checkEvaluation(MapFromEntries(aby), Map(1.toByte -> 10.toByte)) checkEvaluation(MapFromEntries(ash), Map(1.toShort -> 10.toShort)) checkEvaluation(MapFromEntries(alo), Map(1L -> 10L)) @@ -100,6 +103,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val as3 = Literal.create(null, asType) val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) + val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) @@ -111,7 +115,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkExceptionInExpression[RuntimeException]( MapFromEntries(as5), "The first field from a struct (key) can't be null.") - + checkEvaluation(MapFromEntries(as6), Map("b" -> "bb")) } test("Sort Array") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 196194c7340d..f6537833c972 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -406,55 +406,53 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("map_from_entries function") { - def dummy_filter(c: Column): Column = c.isNull || c.isNotNull - - def arrayType(keyType: DataType, valueType: DataType) : StructType = { - StructType(Seq(StructField( - "a", - ArrayType( - StructType(Seq( - StructField("k", keyType, false), - StructField("v", valueType, true))), - false)))) - } + def dummyFilter(c: Column): Column = c.isNull || c.isNotNull + val oneRowDF = Seq(3215).toDF("i") // Test cases with primitive-type keys and values - val irdd = spark.sparkContext.parallelize(Seq( - Row(Seq(Row(1, 10), Row(2, 20), Row(3, 10))), - Row(Seq(Row(1, null), Row(2, null))), - Row(Seq.empty), - Row(null))) - val idf = spark.createDataFrame(irdd, arrayType(IntegerType, IntegerType)) + val idf = Seq( + Seq((1, 10), (2, 20), (3, 10)), + Seq((1, 10), null, (2, 20)), + Seq.empty, + null + ).toDF("a") val iExpected = Seq( Row(Map(1 -> 10, 2 -> 20, 3 -> 10)), - Row(Map(1 -> null, 2 -> null)), + Row(Map(1 -> 10, 2 -> 20)), Row(Map.empty), Row(null)) checkAnswer(idf.select(map_from_entries('a)), iExpected) checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) - checkAnswer(idf.filter(dummy_filter('a)).select(map_from_entries('a)), iExpected) + checkAnswer(idf.filter(dummyFilter('a)).select(map_from_entries('a)), iExpected) + checkAnswer( + oneRowDF.selectExpr("map_from_entries(array(struct(1, null), struct(2, 
null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)) + .selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) // Test cases with non-primitive-type keys and values - val srdd = spark.sparkContext.parallelize(Seq( - Row(Seq(Row("a", "aa"), Row("b", "bb"), Row("c", "aa"))), - Row(Seq(Row("a", null), Row("b", null))), - Row(Seq.empty), - Row(null))) - val sdf = spark.createDataFrame(srdd, arrayType(StringType, StringType)) + val sdf = Seq( + Seq(("a", "aa"), ("b", "bb"), ("c", "aa")), + Seq(("a", "aa"), null, ("b", "bb")), + Seq(("a", null), ("b", null)), + Seq.empty, + null + ).toDF("a") val sExpected = Seq( Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")), + Row(Map("a" -> "aa", "b" -> "bb")), Row(Map("a" -> null, "b" -> null)), Row(Map.empty), Row(null)) checkAnswer(sdf.select(map_from_entries('a)), sExpected) checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) - checkAnswer(sdf.filter(dummy_filter('a)).select(map_from_entries('a)), sExpected) - - // Error test cases - intercept[AnalysisException](idf.selectExpr("map_from_entries(array(struct(null, 1)))")) - intercept[AnalysisException](idf.selectExpr("map_from_entries(array(struct(1, 10), null))")) + checkAnswer(sdf.filter(dummyFilter('a)).select(map_from_entries('a)), sExpected) } test("array contains function") { From 10ace84ebc42f3c39069e3c323ca4f89c227755b Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Sun, 3 Jun 2018 00:57:56 +0200 Subject: [PATCH 3/5] [SPARK-23934][SQL] Ignoring key duplicities --- .../expressions/collectionOperations.scala | 42 +++---------------- .../CollectionExpressionsSuite.scala | 8 +--- 2 files changed, 8 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 710f354d767d..db0b23f8736a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -347,10 +347,9 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { val length = arrayData.numElements() val numEntries = if (nullEntries) (0 until length).count(!arrayData.isNullAt(_)) else length val keyArray = new Array[AnyRef](numEntries) - val keySet = new OpenHashSet[AnyRef]() val valueArray = new Array[AnyRef](numEntries) - var i = 0; - var j = 0; + var i = 0 + var j = 0 while (i < length) { if (!arrayData.isNullAt(i)) { val entry = arrayData.getStruct(i, 2) @@ -358,11 +357,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { if (key == null) { throw new RuntimeException("The first field from a struct (key) can't be null.") } - if (keySet.contains(key)) { - throw new RuntimeException( - "The first field from a struct (key) can't produce duplicates.") - } - keySet.add(key) keyArray.update(j, key) val value = entry.get(1, dataType.valueType) valueArray.update(j, value) @@ -373,26 +367,16 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { ArrayBasedMapData(keyArray, valueArray) } - private def getHashSetDetails(): (String, String) = dataType.keyType match { - case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") - case LongType => ("$mcJ$sp", "Long") - case _ => ("", "Object") - } - override protected def doGenCode(ctx: 
CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { val length = ctx.freshName("length") val numEntries = ctx.freshName("numEntries") - val keySet = ctx.freshName("keySet") - val hsClass = classOf[OpenHashSet[_]].getName - val tagPrefix = "scala.reflect.ClassTag$.MODULE$." - val (hsSuffix, tagSuffix) = getHashSetDetails() val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) val code = if (isKeyPrimitive && isValuePrimitive) { - genCodeForPrimitiveElements(ctx, c, ev.value, keySet, length, numEntries) + genCodeForPrimitiveElements(ctx, c, ev.value, length, numEntries) } else { - genCodeForAnyElements(ctx, c, ev.value, keySet, length, numEntries) + genCodeForAnyElements(ctx, c, ev.value, length, numEntries) } val numEntriesAssignment = if (nullEntries) { val idx = ctx.freshName("idx") @@ -409,7 +393,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { s""" |final int $length = $c.numElements(); |$numEntriesAssignment - |final $hsClass$hsSuffix $keySet = new $hsClass$hsSuffix($tagPrefix$tagSuffix()); |$code """.stripMargin }) @@ -419,15 +402,12 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { ctx: CodegenContext, childVariable: String, length: String, - keySet: String, keyAssignment: (String, String) => String, valueAssignment: (String, String) => String): String = { val entry = ctx.freshName("entry") - val key = ctx.freshName("key") val i = ctx.freshName("idx") val j = ctx.freshName("idx") - val keyType = CodeGenerator.javaType(dataType.keyType) val nullEntryCheck = if (nullEntries) s"if ($childVariable.isNullAt($i)) continue;" else "" val nullKeyCheck = if (dataTypeDetails.get._2) { s""" @@ -444,13 +424,7 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { | $nullEntryCheck | InternalRow $entry = $childVariable.getStruct($i, 2); | $nullKeyCheck - | $keyType $key = ${CodeGenerator.getValue(entry, dataType.keyType, "0")}; - | if ($keySet.contains($key)) { - | throw new RuntimeException( - | "The first field from a struct (key) can't produce duplicates."); - | } - | $keySet.add($key); - | ${keyAssignment(key, j)} + | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), j)} | ${valueAssignment(entry, j)} | $j++; |} @@ -461,7 +435,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { ctx: CodegenContext, childVariable: String, mapData: String, - keySet: String, length: String, numEntries: String): String = { val byteArraySize = ctx.freshName("byteArraySize") @@ -500,7 +473,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { ctx, childVariable, length, - keySet, keyAssignment, valueAssignment ) @@ -510,7 +482,7 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { |final long $valueSectionSize = $vByteSize; |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | ${genCodeForAnyElements(ctx, childVariable, mapData, keySet, length, numEntries)} + | ${genCodeForAnyElements(ctx, childVariable, mapData, length, numEntries)} |} else { | final byte[] $data = new byte[(int)$byteArraySize]; | UnsafeMapData $unsafeMapData = new UnsafeMapData(); @@ -530,7 +502,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { ctx: CodegenContext, childVariable: String, mapData: String, - keySet: String, length: String, numEntries: String): String = { 
val keys = ctx.freshName("keys") @@ -551,7 +522,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { ctx, childVariable, length, - keySet, keyAssignment, valueAssignment) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 88b41829f58c..1a3f264d0068 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,9 +106,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) checkEvaluation(MapFromEntries(ai2), Map.empty) checkEvaluation(MapFromEntries(ai3), null) - checkExceptionInExpression[RuntimeException]( - MapFromEntries(ai4), - "The first field from a struct (key) can't produce duplicates.") + checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) checkExceptionInExpression[RuntimeException]( MapFromEntries(ai5), "The first field from a struct (key) can't be null.") @@ -131,9 +129,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) checkEvaluation(MapFromEntries(as2), Map.empty) checkEvaluation(MapFromEntries(as3), null) - checkExceptionInExpression[RuntimeException]( - MapFromEntries(as4), - "The first field from a struct (key) can't produce duplicates.") + checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) checkExceptionInExpression[RuntimeException]( MapFromEntries(as5), "The first field from a struct (key) can't be null.") From 599656eed53222d5e243db663bf52cc3c1e802a7 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Thu, 21 Jun 2018 16:55:41 +0200 Subject: [PATCH 4/5] [SPARK-23934][SQL] Handling of null entries --- .../expressions/codegen/CodeGenerator.scala | 30 ++++++ .../expressions/collectionOperations.scala | 95 +++++++------------ .../CollectionExpressionsSuite.scala | 10 +- .../spark/sql/DataFrameFunctionsSuite.scala | 4 +- 4 files changed, 69 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 66315e590625..a690880e5c8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -819,6 +819,36 @@ class CodegenContext { } } + /** + * Generates code to do null safe execution when accessing properties of complex + * ArrayData elements. + * + * @param nullElements used to decide whether the ArrayData might contain null or not. + * @param isNull a variable indicating whether the result will be evaluated to null or not. + * @param arrayData a variable name representing the ArrayData. + * @param execute the code that should be executed only if the ArrayData doesn't contain + * any null. 
+ */ + def nullArrayElementsSaveExec( + nullElements: Boolean, + isNull: String, + arrayData: String)( + execute: String): String = { + val i = freshName("idx") + if (nullElements) { + s""" + |for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) { + | $isNull |= $arrayData.isNullAt($i); + |} + |if (!$isNull) { + | $execute + |} + """.stripMargin + } else { + execute + } + } + /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fa69f28d9416..3afabe14606e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -500,6 +500,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { private def nullEntries: Boolean = dataTypeDetails.get._3 + override def nullable: Boolean = child.nullable || nullEntries + override def dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { @@ -510,24 +512,26 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { override protected def nullSafeEval(input: Any): Any = { val arrayData = input.asInstanceOf[ArrayData] - val length = arrayData.numElements() - val numEntries = if (nullEntries) (0 until length).count(!arrayData.isNullAt(_)) else length + val numEntries = arrayData.numElements() + var i = 0 + if(nullEntries) { + while (i < numEntries) { + if (arrayData.isNullAt(i)) return null + i += 1 + } + } val keyArray = new Array[AnyRef](numEntries) val valueArray = new Array[AnyRef](numEntries) - var i = 0 - var j = 0 - while (i < length) { - if (!arrayData.isNullAt(i)) { - val entry = arrayData.getStruct(i, 2) - val key = entry.get(0, dataType.keyType) - if (key == null) { - throw new RuntimeException("The first field from a struct (key) can't be null.") - } - keyArray.update(j, key) - val value = entry.get(1, dataType.valueType) - valueArray.update(j, value) - j += 1 + i = 0 + while (i < numEntries) { + val entry = arrayData.getStruct(i, 2) + val key = entry.get(0, dataType.keyType) + if (key == null) { + throw new RuntimeException("The first field from a struct (key) can't be null.") } + keyArray.update(i, key) + val value = entry.get(1, dataType.valueType) + valueArray.update(i, value) i += 1 } ArrayBasedMapData(keyArray, valueArray) @@ -535,46 +539,33 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { - val length = ctx.freshName("length") val numEntries = ctx.freshName("numEntries") val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) val code = if (isKeyPrimitive && isValuePrimitive) { - genCodeForPrimitiveElements(ctx, c, ev.value, length, numEntries) + genCodeForPrimitiveElements(ctx, c, ev.value, numEntries) } else { - genCodeForAnyElements(ctx, c, ev.value, length, numEntries) + genCodeForAnyElements(ctx, c, ev.value, numEntries) } - val numEntriesAssignment = if (nullEntries) { - val idx = ctx.freshName("idx") + 
ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) { s""" - |int $numEntries = 0; - |for (int $idx = 0; $idx < $length; $idx++) { - | if (!$c.isNullAt($idx)) $numEntries++; - |} + |final int $numEntries = $c.numElements(); + |$code """.stripMargin - } else { - s"final int $numEntries = $length;" } - - s""" - |final int $length = $c.numElements(); - |$numEntriesAssignment - |$code - """.stripMargin }) } private def genCodeForAssignmentLoop( ctx: CodegenContext, childVariable: String, - length: String, + mapData: String, + numEntries: String, keyAssignment: (String, String) => String, valueAssignment: (String, String) => String): String = { val entry = ctx.freshName("entry") val i = ctx.freshName("idx") - val j = ctx.freshName("idx") - val nullEntryCheck = if (nullEntries) s"if ($childVariable.isNullAt($i)) continue;" else "" val nullKeyCheck = if (dataTypeDetails.get._2) { s""" |if ($entry.isNullAt(0)) { @@ -586,13 +577,11 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { } s""" - |for (int $i = 0, $j = 0; $i < $length; $i++) { - | $nullEntryCheck + |for (int $i = 0; $i < $numEntries; $i++) { | InternalRow $entry = $childVariable.getStruct($i, 2); | $nullKeyCheck - | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), j)} - | ${valueAssignment(entry, j)} - | $j++; + | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)} + | ${valueAssignment(entry, i)} |} """.stripMargin } @@ -601,7 +590,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { ctx: CodegenContext, childVariable: String, mapData: String, - length: String, numEntries: String): String = { val byteArraySize = ctx.freshName("byteArraySize") val keySectionSize = ctx.freshName("keySectionSize") @@ -638,7 +626,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { val assignmentLoop = genCodeForAssignmentLoop( ctx, childVariable, - length, + mapData, + numEntries, keyAssignment, valueAssignment ) @@ -648,7 +637,7 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { |final long $valueSectionSize = $vByteSize; |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | ${genCodeForAnyElements(ctx, childVariable, mapData, length, numEntries)} + | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)} |} else { | final byte[] $data = new byte[(int)$byteArraySize]; | UnsafeMapData $unsafeMapData = new UnsafeMapData(); @@ -668,7 +657,6 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { ctx: CodegenContext, childVariable: String, mapData: String, - length: String, numEntries: String): String = { val keys = ctx.freshName("keys") val values = ctx.freshName("values") @@ -687,7 +675,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { val assignmentLoop = genCodeForAssignmentLoop( ctx, childVariable, - length, + mapData, + numEntries, keyAssignment, valueAssignment) @@ -2218,24 +2207,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) } - if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) }) } - private def nullElementsProtection( - ev: ExprCode, - childVariableName: String, - coreLogic: String): String = { - s""" - |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); 
z++) { - | ${ev.isNull} |= $childVariableName.isNullAt(z); - |} - |if (!${ev.isNull}) { - | $coreLogic - |} - """.stripMargin - } - private def genCodeForNumberOfElements( ctx: CodegenContext, childVariableName: String) : (String, String) = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b11b2fdf5534..5b8cf5128fe2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -99,9 +99,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) - val aby = Literal.create(Seq(r(1.toByte, 10.toByte)), arrayType(ByteType, ByteType)) - val ash = Literal.create(Seq(r(1.toShort, 10.toShort)), arrayType(ShortType, ShortType)) - val alo = Literal.create(Seq(r(1L, 10L)), arrayType(LongType, LongType)) checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) @@ -111,10 +108,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkExceptionInExpression[RuntimeException]( MapFromEntries(ai5), "The first field from a struct (key) can't be null.") - checkEvaluation(MapFromEntries(ai6), Map(2 -> 20)) - checkEvaluation(MapFromEntries(aby), Map(1.toByte -> 10.toByte)) - checkEvaluation(MapFromEntries(ash), Map(1.toShort -> 10.toShort)) - checkEvaluation(MapFromEntries(alo), Map(1L -> 10L)) + checkEvaluation(MapFromEntries(ai6), null) // Non-primitive-type keys and values val asType = arrayType(StringType, StringType) @@ -134,7 +128,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkExceptionInExpression[RuntimeException]( MapFromEntries(as5), "The first field from a struct (key) can't be null.") - checkEvaluation(MapFromEntries(as6), Map("b" -> "bb")) + checkEvaluation(MapFromEntries(as6), null) } test("Sort Array") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 46d91461f410..25fdbab74512 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -646,7 +646,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ).toDF("a") val iExpected = Seq( Row(Map(1 -> 10, 2 -> 20, 3 -> 10)), - Row(Map(1 -> 10, 2 -> 20)), + Row(null), Row(Map.empty), Row(null)) @@ -673,7 +673,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ).toDF("a") val sExpected = Seq( Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")), - Row(Map("a" -> "aa", "b" -> "bb")), + Row(null), Row(Map("a" -> null, "b" -> null)), Row(Map.empty), Row(null)) From 4eaedc50f92a3dd1ee2100fbbd5ac951344ece75 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Thu, 21 Jun 2018 17:47:38 +0200 Subject: [PATCH 5/5] [SPARK-23934][SQL] Fixing scala style --- .../expressions/codegen/CodeGenerator.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a690880e5c8c..4cc0968911cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -820,15 +820,15 @@ class CodegenContext { } /** - * Generates code to do null safe execution when accessing properties of complex - * ArrayData elements. - * - * @param nullElements used to decide whether the ArrayData might contain null or not. - * @param isNull a variable indicating whether the result will be evaluated to null or not. - * @param arrayData a variable name representing the ArrayData. - * @param execute the code that should be executed only if the ArrayData doesn't contain - * any null. - */ + * Generates code to do null safe execution when accessing properties of complex + * ArrayData elements. + * + * @param nullElements used to decide whether the ArrayData might contain null or not. + * @param isNull a variable indicating whether the result will be evaluated to null or not. + * @param arrayData a variable name representing the ArrayData. + * @param execute the code that should be executed only if the ArrayData doesn't contain + * any null. + */ def nullArrayElementsSaveExec( nullElements: Boolean, isNull: String,
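As a reading aid for nullArrayElementsSaveExec, here is a plain-Scala analogue of the guard it generates: scan the array once and evaluate the wrapped logic only when no element is null. This is a hypothetical harness over ArrayData, not code from the patch:

    import org.apache.spark.sql.catalyst.util.ArrayData

    // Mirrors the generated loop: the result becomes null (None here) as soon
    // as any array element is null; otherwise the core logic runs.
    def withNullElementGuard[T](nullElements: Boolean, arrayData: ArrayData)(execute: => T): Option[T] = {
      var isNull = false
      if (nullElements) {
        var i = 0
        while (!isNull && i < arrayData.numElements()) {
          isNull = arrayData.isNullAt(i)
          i += 1
        }
      }
      if (isNull) None else Some(execute)
    }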