diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index c597a2a70944..ea4dee174e74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -353,7 +353,7 @@ object MapObjects {
     val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
     val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
     val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
-    MapObjects(loopVar, function(loopVar), inputData)
+    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
   }
 }
 
@@ -365,14 +365,20 @@ object MapObjects {
  * The following collection ObjectTypes are currently supported:
  *   Seq, Array, ArrayData, java.util.List
  *
- * @param loopVar A place holder that used as the loop variable when iterate the collection, and
- *                used as input for the `lambdaFunction`. It also carries the element type info.
+ * @param loopValue the name of the loop variable used when iterating over the collection, and
+ *                  used as input for the `lambdaFunction`
+ * @param loopIsNull the nullability of the loop variable used when iterating over the collection,
+ *                   and used as input for the `lambdaFunction`
+ * @param loopVarDataType the data type of the loop variable used when iterating over the
+ *                        collection, and used as input for the `lambdaFunction`
  * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
  *                       to handle collection elements.
  * @param inputData An expression that when evaluated returns a collection object.
  */
 case class MapObjects private(
-    loopVar: LambdaVariable,
+    loopValue: String,
+    loopIsNull: String,
+    loopVarDataType: DataType,
     lambdaFunction: Expression,
     inputData: Expression) extends Expression with NonSQLExpression {
 
@@ -386,9 +392,9 @@ case class MapObjects private(
   override def dataType: DataType = ArrayType(lambdaFunction.dataType)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val elementJavaType = ctx.javaType(loopVar.dataType)
-    ctx.addMutableState("boolean", loopVar.isNull, "")
-    ctx.addMutableState(elementJavaType, loopVar.value, "")
+    val elementJavaType = ctx.javaType(loopVarDataType)
+    ctx.addMutableState("boolean", loopIsNull, "")
+    ctx.addMutableState(elementJavaType, loopValue, "")
     val genInputData = inputData.genCode(ctx)
     val genFunction = lambdaFunction.genCode(ctx)
     val dataLength = ctx.freshName("dataLength")
@@ -443,11 +449,11 @@ case class MapObjects private(
     }
 
     val loopNullCheck = inputData.dataType match {
-      case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
+      case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
       // The element of primitive array will never be null.
       case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
-        s"${loopVar.isNull} = false"
-      case _ => s"${loopVar.isNull} = ${loopVar.value} == null;"
+        s"$loopIsNull = false"
+      case _ => s"$loopIsNull = $loopValue == null;"
     }
 
     val code = s"""
@@ -462,7 +468,7 @@ case class MapObjects private(
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
-          ${loopVar.value} = ($elementJavaType) ($getLoopVar);
+          $loopValue = ($elementJavaType) ($getLoopVar);
           $loopNullCheck
 
           ${genFunction.code}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index f955120dc543..32fcf84b02f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -74,6 +74,16 @@ object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
   }
 }
 
+object MapTypeBufferAgg extends Aggregator[Int, Map[Int, Int], Int] {
+  override def zero: Map[Int, Int] = Map.empty
+  override def reduce(b: Map[Int, Int], a: Int): Map[Int, Int] = b
+  override def finish(reduction: Map[Int, Int]): Int = 1
+  override def merge(b1: Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] = b1
+  override def bufferEncoder: Encoder[Map[Int, Int]] = ExpressionEncoder()
+  override def outputEncoder: Encoder[Int] = ExpressionEncoder()
+}
+
+
 object NameAgg extends Aggregator[AggData, String, String] {
   def zero: String = ""
   def reduce(b: String, a: AggData): String = a.b + b
@@ -290,4 +300,9 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.groupByKey(_.a).agg(NullResultAgg.toColumn),
       1 -> AggData(1, "one"), 2 -> null)
   }
+
+  test("SPARK-16100: use Map as the buffer type of Aggregator") {
+    val ds = Seq(1, 2, 3).toDS()
+    checkDataset(ds.select(MapTypeBufferAgg.toColumn), 1)
+  }
 }
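For context, here is a minimal standalone sketch (not part of the patch) of the scenario the new test exercises: an Aggregator whose buffer type is a Map, which is what drives the buffer serializer through MapObjects codegen. The object names CountPerValueAgg and MapBufferExample, and the local SparkSession setup, are illustrative assumptions; only the Map-buffered Aggregator pattern and the ds.select(agg.toColumn) usage come from the test added above.

import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical aggregator: counts occurrences per value in a Map buffer,
// then reports how many distinct values were seen.
object CountPerValueAgg extends Aggregator[Int, Map[Int, Int], Int] {
  override def zero: Map[Int, Int] = Map.empty
  override def reduce(b: Map[Int, Int], a: Int): Map[Int, Int] =
    b.updated(a, b.getOrElse(a, 0) + 1)
  override def merge(b1: Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] =
    b2.foldLeft(b1) { case (acc, (k, v)) => acc.updated(k, acc.getOrElse(k, 0) + v) }
  override def finish(reduction: Map[Int, Int]): Int = reduction.size
  // The Map-typed buffer encoder is what goes through MapObjects during codegen.
  override def bufferEncoder: Encoder[Map[Int, Int]] = ExpressionEncoder()
  override def outputEncoder: Encoder[Int] = ExpressionEncoder()
}

object MapBufferExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("map-buffer-agg").getOrCreate()
    import spark.implicits._
    // Same shape as the new test: aggregate a Dataset[Int] with a Map-buffered Aggregator.
    val distinctCount = Seq(1, 2, 2, 3).toDS().select(CountPerValueAgg.toColumn).head()
    assert(distinctCount == 3)
    spark.stop()
  }
}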