Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ object MapObjects {
val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
MapObjects(loopVar, function(loopVar), inputData)
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
}
}

Expand All @@ -365,14 +365,20 @@ object MapObjects {
* The following collection ObjectTypes are currently supported:
* Seq, Array, ArrayData, java.util.List
*
* @param loopVar A place holder that used as the loop variable when iterate the collection, and
* used as input for the `lambdaFunction`. It also carries the element type info.
* @param loopValue the name of the loop variable that is used when iterating over the
* collection, and is used as input for the `lambdaFunction`
* @param loopIsNull the name of the boolean variable that holds the nullity of the loop variable
* while iterating over the collection, and is used as input for the `lambdaFunction`
* @param loopVarDataType the data type of the loop variable that is used when iterating over the
* collection, and is used as input for the `lambdaFunction`
* @param lambdaFunction A function that takes the `loopVar` as input, and is used as the lambda
* function to handle collection elements.
* @param inputData An expression that when evaluated returns a collection object.
*/
case class MapObjects private(
loopVar: LambdaVariable,
loopValue: String,
loopIsNull: String,
loopVarDataType: DataType,
lambdaFunction: Expression,
inputData: Expression) extends Expression with NonSQLExpression {

Expand All @@ -386,9 +392,9 @@ case class MapObjects private(
override def dataType: DataType = ArrayType(lambdaFunction.dataType)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val elementJavaType = ctx.javaType(loopVar.dataType)
ctx.addMutableState("boolean", loopVar.isNull, "")
ctx.addMutableState(elementJavaType, loopVar.value, "")
val elementJavaType = ctx.javaType(loopVarDataType)
ctx.addMutableState("boolean", loopIsNull, "")
ctx.addMutableState(elementJavaType, loopValue, "")
val genInputData = inputData.genCode(ctx)
val genFunction = lambdaFunction.genCode(ctx)
val dataLength = ctx.freshName("dataLength")
Expand Down Expand Up @@ -443,11 +449,11 @@ case class MapObjects private(
}

// Java statement that sets the loop variable's null flag for the current element.
// The snippet is spliced verbatim into the generated while-loop body, so every branch
// must be a complete statement terminated by `;`.
val loopNullCheck = inputData.dataType match {
case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
// The element of primitive array will never be null.
case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
s"${loopVar.isNull} = false;"
case _ => s"${loopVar.isNull} = ${loopVar.value} == null;"
s"$loopIsNull = false;"
case _ => s"$loopIsNull = $loopValue == null;"
}

val code = s"""
Expand All @@ -462,7 +468,7 @@ case class MapObjects private(

int $loopIndex = 0;
while ($loopIndex < $dataLength) {
${loopVar.value} = ($elementJavaType) ($getLoopVar);
$loopValue = ($elementJavaType) ($getLoopVar);
$loopNullCheck

${genFunction.code}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
}


/**
 * Aggregator fixture whose intermediate buffer type is `Map[Int, Int]`.
 *
 * The buffer is never modified: `reduce` and `merge` hand back their first argument
 * unchanged, and `finish` always yields 1.
 */
object MapTypeBufferAgg extends Aggregator[Int, Map[Int, Int], Int] {
  // The initial buffer is an empty map.
  override def zero: Map[Int, Int] = Map.empty[Int, Int]

  // The incoming element is ignored; the buffer is returned as-is.
  override def reduce(b: Map[Int, Int], a: Int): Map[Int, Int] = {
    b
  }

  // The second buffer is ignored; the first one is returned as-is.
  override def merge(b1: Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] = {
    b1
  }

  // The result is the constant 1 regardless of the final buffer.
  override def finish(reduction: Map[Int, Int]): Int = {
    1
  }

  override def bufferEncoder: Encoder[Map[Int, Int]] = ExpressionEncoder[Map[Int, Int]]()

  override def outputEncoder: Encoder[Int] = ExpressionEncoder[Int]()
}


object NameAgg extends Aggregator[AggData, String, String] {
def zero: String = ""
def reduce(b: String, a: AggData): String = a.b + b
Expand Down Expand Up @@ -290,4 +300,9 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ds.groupByKey(_.a).agg(NullResultAgg.toColumn),
1 -> AggData(1, "one"), 2 -> null)
}

// Selecting the Map-buffered aggregator over a three-element Dataset is expected to
// produce the single value 1.
test("SPARK-16100: use Map as the buffer type of Aggregator") {
  val aggregated = Seq(1, 2, 3).toDS().select(MapTypeBufferAgg.toColumn)
  checkDataset(aggregated, 1)
}
}