diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 55ab6b3358e3..c3547ae0a1b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -54,6 +54,7 @@ private[spark] object SQLConf {
   // considered hints and may be ignored by future versions of Spark SQL.
   val EXTERNAL_SORT = "spark.sql.planner.externalSort"
   val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin"
+  val SORTMERGE_AGGREGATE = "spark.sql.planner.sortMergeAggregate"
 
   // This is only used for the thriftserver
   val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -170,6 +171,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
    */
   private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean
 
+  /**
+   * Sort-merge aggregation sorts the input by the grouping key, computes each group's
+   * aggregates in full, and then moves on to the next group. Because only one group's buffer
+   * is alive at a time, it can use far less memory than hash-based aggregation.
+   */
+  private[spark] def sortMergeAggregateEnabled: Boolean =
+    getConf(SORTMERGE_AGGREGATE, "false").toBoolean
+
   /**
    * When set to true, Spark SQL will use the Janino at runtime to generate custom bytecode
    * that evaluates expressions found in queries. In general this custom code runs much faster
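For reviewers who want to try the new flag out, here is a minimal session sketch. The config key and its `"false"` default come from the change above; the table name is made up, and `spark.sql.codegen` is the pre-existing codegen switch:

```scala
// Sketch only: assumes a live SQLContext in `sqlContext` and a registered
// temp table named "records" -- both illustrative, not part of this patch.
sqlContext.setConf("spark.sql.planner.sortMergeAggregate", "true") // new flag, default "false"
sqlContext.setConf("spark.sql.codegen", "true")                    // existing codegen flag

// With both flags on, a grouped aggregate should plan as SortMergeGeneratedAggregate.
sqlContext.sql("SELECT key, SUM(value) FROM records GROUP BY key").explain()
```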
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index ba2c8f53d702..394c8131e782 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,12 +17,8 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.TaskContext
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.trees._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types._
 
 case class AggregateEvaluation(
@@ -31,329 +27,174 @@ case class AggregateEvaluation(
     update: Seq[Expression],
     result: Expression)
 
-/**
- * :: DeveloperApi ::
- * Alternate version of aggregation that leverages projection and thus code generation.
- * Aggregations are converted into a set of projections from a aggregation buffer tuple back onto
- * itself. Currently only used for simple aggregations like SUM, COUNT, or AVERAGE are supported.
- *
- * @param partial if true then aggregation is done partially on local data without shuffling to
- *                ensure all values where `groupingExpressions` are equal are present.
- * @param groupingExpressions expressions that are evaluated to determine grouping.
- * @param aggregateExpressions expressions that are computed for each group.
- * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
- * @param child the input data source.
- */
-@DeveloperApi
-case class GeneratedAggregate(
-    partial: Boolean,
-    groupingExpressions: Seq[Expression],
-    aggregateExpressions: Seq[NamedExpression],
-    unsafeEnabled: Boolean,
-    child: SparkPlan)
-  extends UnaryNode {
+trait GeneratedAggregate {
+  self: SparkPlan =>
 
-  override def requiredChildDistribution: Seq[Distribution] =
-    if (partial) {
-      UnspecifiedDistribution :: Nil
-    } else {
-      if (groupingExpressions == Nil) {
-        AllTuples :: Nil
-      } else {
-        ClusteredDistribution(groupingExpressions) :: Nil
-      }
-    }
+  val groupingExpressions: Seq[Expression]
+  val aggregateExpressions: Seq[NamedExpression]
+  val unsafeEnabled: Boolean
+  val child: SparkPlan
 
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
 
-  protected override def doExecute(): RDD[InternalRow] = {
-    val aggregatesToCompute = aggregateExpressions.flatMap { a =>
-      a.collect { case agg: AggregateExpression => agg}
-    }
-
-    // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
-    // (in test "aggregation with codegen").
-    val computeFunctions = aggregatesToCompute.map {
-      case c @ Count(expr) =>
-        // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
-        // UnscaledValue will be null if and only if x is null; helps with Average on decimals
-        val toCount = expr match {
-          case UnscaledValue(e) => e
-          case _ => expr
-        }
-        val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
-        val initialValue = Literal(0L)
-        val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
-        val result = currentCount
-
-        AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
-      case s @ Sum(expr) =>
-        val calcType =
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              DecimalType.Unlimited
-            case _ =>
-              expr.dataType
-          }
-
-        val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
-        val initialValue = Literal.create(null, calcType)
-
-        // Coalesce avoids double calculation...
-        // but really, common sub expression elimination would be better....
-        val zero = Cast(Literal(0), calcType)
-        val updateFunction = Coalesce(
-          Add(
-            Coalesce(currentSum :: zero :: Nil),
-            Cast(expr, calcType)
-          ) :: currentSum :: zero :: Nil)
-        val result =
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              Cast(currentSum, s.dataType)
-            case _ => currentSum
-          }
-
-        AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
-      case cs @ CombineSum(expr) =>
-        val calcType =
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              DecimalType.Unlimited
-            case _ =>
-              expr.dataType
-          }
-
-        val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
-        val initialValue = Literal.create(null, calcType)
-
-        // Coalesce avoids double calculation...
-        // but really, common sub expression elimination would be better....
-        val zero = Cast(Literal(0), calcType)
-        // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
-        // UnscaledValue will be null if and only if x is null; helps with Average on decimals
-        val actualExpr = expr match {
-          case UnscaledValue(e) => e
-          case _ => expr
-        }
-        // partial sum result can be null only when no input rows present
-        val updateFunction = If(
-          IsNotNull(actualExpr),
-          Coalesce(
-            Add(
-              Coalesce(currentSum :: zero :: Nil),
-              Cast(expr, calcType)) :: currentSum :: zero :: Nil),
-          currentSum)
-
-        val result =
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              Cast(currentSum, cs.dataType)
-            case _ => currentSum
-          }
-
-        AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
-      case m @ Max(expr) =>
-        val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
-        val initialValue = Literal.create(null, expr.dataType)
-        val updateMax = MaxOf(currentMax, expr)
-
-        AggregateEvaluation(
-          currentMax :: Nil,
-          initialValue :: Nil,
-          updateMax :: Nil,
-          currentMax)
-
-      case m @ Min(expr) =>
-        val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)()
-        val initialValue = Literal.create(null, expr.dataType)
-        val updateMin = MinOf(currentMin, expr)
-
-        AggregateEvaluation(
-          currentMin :: Nil,
-          initialValue :: Nil,
-          updateMin :: Nil,
-          currentMin)
-
-      case CollectHashSet(Seq(expr)) =>
-        val set =
-          AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
-        val initialValue = NewSet(expr.dataType)
-        val addToSet = AddItemToSet(expr, set)
-
-        AggregateEvaluation(
-          set :: Nil,
-          initialValue :: Nil,
-          addToSet :: Nil,
-          set)
-
-      case CombineSetsAndCount(inputSet) =>
-        val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
-        val set =
-          AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
-        val initialValue = NewSet(elementType)
-        val collectSets = CombineSets(set, inputSet)
-
-        AggregateEvaluation(
-          set :: Nil,
-          initialValue :: Nil,
-          collectSets :: Nil,
-          CountSet(set))
-
-      case o => sys.error(s"$o can't be codegened.")
-    }
-
-    val computationSchema = computeFunctions.flatMap(_.schema)
-
-    val resultMap: Map[TreeNodeRef, Expression] =
-      aggregatesToCompute.zip(computeFunctions).map {
-        case (agg, func) => new TreeNodeRef(agg) -> func.result
-      }.toMap
-
-    val namedGroups = groupingExpressions.zipWithIndex.map {
-      case (ne: NamedExpression, _) => (ne, ne.toAttribute)
-      case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
-    }
-
-    // The set of expressions that produce the final output given the aggregation buffer and the
-    // grouping expressions.
-    val resultExpressions = aggregateExpressions.map(_.transform {
-      case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
-      case e: Expression =>
-        namedGroups.collectFirst {
-          case (expr, attr) if expr semanticEquals e => attr
-        }.getOrElse(e)
-    })
-
-    val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
-
-    val groupKeySchema: StructType = {
-      val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
-        // This is a dummy field name
-        StructField(idx.toString, expr.dataType, expr.nullable)
-      }
-      StructType(fields)
-    }
-
-    val schemaSupportsUnsafe: Boolean = {
-      UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
-        UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
-    }
-
-    child.execute().mapPartitions { iter =>
-      // Builds a new custom class for holding the results of aggregation for a group.
-      val initialValues = computeFunctions.flatMap(_.initialValues)
-      val newAggregationBuffer = newProjection(initialValues, child.output)
-      log.info(s"Initial values: ${initialValues.mkString(",")}")
-
-      // A projection that computes the group given an input tuple.
-      val groupProjection = newProjection(groupingExpressions, child.output)
-      log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
-
-      // A projection that is used to update the aggregate values for a group given a new tuple.
-      // This projection should be targeted at the current values for the group and then applied
-      // to a joined row of the current values with the new input row.
-      val updateExpressions = computeFunctions.flatMap(_.update)
-      val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
-      val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
-      log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
-
-      // A projection that produces the final result, given a computation.
-      val resultProjectionBuilder =
-        newMutableProjection(
-          resultExpressions,
-          namedGroups.map(_._2) ++ computationSchema)
-      log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
-
-      val joinedRow = new JoinedRow3
-
-      if (groupingExpressions.isEmpty) {
-        // TODO: Codegening anything other than the updateProjection is probably over kill.
-        val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
-        var currentRow: InternalRow = null
-        updateProjection.target(buffer)
-
-        while (iter.hasNext) {
-          currentRow = iter.next()
-          updateProjection(joinedRow(buffer, currentRow))
-        }
-
-        val resultProjection = resultProjectionBuilder()
-        Iterator(resultProjection(buffer))
-      } else if (unsafeEnabled && schemaSupportsUnsafe) {
-        log.info("Using Unsafe-based aggregator")
-        val aggregationMap = new UnsafeFixedWidthAggregationMap(
-          newAggregationBuffer(EmptyRow),
-          aggregationBufferSchema,
-          groupKeySchema,
-          TaskContext.get.taskMemoryManager(),
-          1024 * 16, // initial capacity
-          false // disable tracking of performance metrics
-        )
-
-        while (iter.hasNext) {
-          val currentRow: InternalRow = iter.next()
-          val groupKey: InternalRow = groupProjection(currentRow)
-          val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
-          updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
-        }
-
-        new Iterator[InternalRow] {
-          private[this] val mapIterator = aggregationMap.iterator()
-          private[this] val resultProjection = resultProjectionBuilder()
-
-          def hasNext: Boolean = mapIterator.hasNext
-
-          def next(): InternalRow = {
-            val entry = mapIterator.next()
-            val result = resultProjection(joinedRow(entry.key, entry.value))
-            if (hasNext) {
-              result
-            } else {
-              // This is the last element in the iterator, so let's free the buffer. Before we do,
-              // though, we need to make a defensive copy of the result so that we don't return an
-              // object that might contain dangling pointers to the freed memory
-              val resultCopy = result.copy()
-              aggregationMap.free()
-              resultCopy
-            }
-          }
-        }
-      } else {
-        if (unsafeEnabled) {
-          log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
-        }
-        val buffers = new java.util.HashMap[InternalRow, MutableRow]()
-
-        var currentRow: InternalRow = null
-        while (iter.hasNext) {
-          currentRow = iter.next()
-          val currentGroup = groupProjection(currentRow)
-          var currentBuffer = buffers.get(currentGroup)
-          if (currentBuffer == null) {
-            currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
-            buffers.put(currentGroup, currentBuffer)
-          }
-          // Target the projection at the current aggregation buffer and then project the updated
-          // values.
-          updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
-        }
-
-        new Iterator[InternalRow] {
-          private[this] val resultIterator = buffers.entrySet.iterator()
-          private[this] val resultProjection = resultProjectionBuilder()
-
-          def hasNext: Boolean = resultIterator.hasNext
-
-          def next(): InternalRow = {
-            val currentGroup = resultIterator.next()
-            resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
-          }
-        }
-      }
-    }
-  }
+  @transient protected lazy val aggregatesToCompute = aggregateExpressions.flatMap { a =>
+    a.collect { case agg: AggregateExpression => agg}
+  }
+
+  @transient protected lazy val namedGroups = groupingExpressions.zipWithIndex.map {
+    case (ne: NamedExpression, _) => (ne, ne.toAttribute)
+    case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
+  }
+
+  // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
+  // (in test "aggregation with codegen").
+  @transient protected lazy val computeFunctions = aggregatesToCompute.map {
+    case c @ Count(expr) =>
+      // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
+      // UnscaledValue will be null if and only if x is null; helps with Average on decimals
+      val toCount = expr match {
+        case UnscaledValue(e) => e
+        case _ => expr
+      }
+      val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
+      val initialValue = Literal(0L)
+      val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
+      val result = currentCount
+
+      AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+
+    case s @ Sum(expr) =>
+      val calcType =
+        expr.dataType match {
+          case DecimalType.Fixed(_, _) =>
+            DecimalType.Unlimited
+          case _ =>
+            expr.dataType
+        }
+
+      val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
+      val initialValue = Literal.create(null, calcType)
+
+      // Coalesce avoids double calculation...
+      // but really, common sub expression elimination would be better....
+      val zero = Cast(Literal(0), calcType)
+      val updateFunction = Coalesce(
+        Add(
+          Coalesce(currentSum :: zero :: Nil),
+          Cast(expr, calcType)
+        ) :: currentSum :: zero :: Nil)
+      val result =
+        expr.dataType match {
+          case DecimalType.Fixed(_, _) =>
+            Cast(currentSum, s.dataType)
+          case _ => currentSum
+        }
+
+      AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+
+    case cs @ CombineSum(expr) =>
+      val calcType =
+        expr.dataType match {
+          case DecimalType.Fixed(_, _) =>
+            DecimalType.Unlimited
+          case _ =>
+            expr.dataType
+        }
+
+      val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
+      val initialValue = Literal.create(null, calcType)
+
+      // Coalesce avoids double calculation...
+      // but really, common sub expression elimination would be better....
+      val zero = Cast(Literal(0), calcType)
+      // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
+      // UnscaledValue will be null if and only if x is null; helps with Average on decimals
+      val actualExpr = expr match {
+        case UnscaledValue(e) => e
+        case _ => expr
+      }
+      // partial sum result can be null only when no input rows present
+      val updateFunction = If(
+        IsNotNull(actualExpr),
+        Coalesce(
+          Add(
+            Coalesce(currentSum :: zero :: Nil),
+            Cast(expr, calcType)) :: currentSum :: zero :: Nil),
+        currentSum)
+
+      val result =
+        expr.dataType match {
+          case DecimalType.Fixed(_, _) =>
+            Cast(currentSum, cs.dataType)
+          case _ => currentSum
+        }
+
+      AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+
+    case m @ Max(expr) =>
+      val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
+      val initialValue = Literal.create(null, expr.dataType)
+      val updateMax = MaxOf(currentMax, expr)
+
+      AggregateEvaluation(
+        currentMax :: Nil,
+        initialValue :: Nil,
+        updateMax :: Nil,
+        currentMax)
+
+    case m @ Min(expr) =>
+      val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)()
+      val initialValue = Literal.create(null, expr.dataType)
+      val updateMin = MinOf(currentMin, expr)
+
+      AggregateEvaluation(
+        currentMin :: Nil,
+        initialValue :: Nil,
+        updateMin :: Nil,
+        currentMin)
+
+    case CollectHashSet(Seq(expr)) =>
+      val set =
+        AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
+      val initialValue = NewSet(expr.dataType)
+      val addToSet = AddItemToSet(expr, set)
+
+      AggregateEvaluation(
+        set :: Nil,
+        initialValue :: Nil,
+        addToSet :: Nil,
+        set)
+
+    case CombineSetsAndCount(inputSet) =>
+      val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
+      val set =
+        AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
+      val initialValue = NewSet(elementType)
+      val collectSets = CombineSets(set, inputSet)
+
+      AggregateEvaluation(
+        set :: Nil,
+        initialValue :: Nil,
+        collectSets :: Nil,
+        CountSet(set))
+
+    case o => sys.error(s"$o can't be codegened.")
+  }
+
+  @transient protected lazy val computationSchema = computeFunctions.flatMap(_.schema)
+
+  @transient protected lazy val resultMap: Map[TreeNodeRef, Expression] =
+    aggregatesToCompute.zip(computeFunctions).map {
+      case (agg, func) => new TreeNodeRef(agg) -> func.result
+    }.toMap
+
+  // The set of expressions that produce the final output given the aggregation buffer and the
+  // grouping expressions.
+  @transient protected lazy val resultExpressions = aggregateExpressions.map(_.transform {
+    case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
+    case e: Expression =>
+      namedGroups.collectFirst {
+        case (expr, attr) if expr semanticEquals e => attr
+      }.getOrElse(e)
+  })
 }
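The refactor above reduces `GeneratedAggregate` to the shared, execution-independent pieces: every supported aggregate is described by an `AggregateEvaluation` quadruple (buffer schema, initial values, update expressions, result expression). A plain-Scala analogue of that contract, with purely illustrative names in place of Catalyst expressions, may make the shape easier to see:

```scala
// Illustrative only: Catalyst expressions replaced by ordinary functions.
// `init` seeds the buffer, `update` folds one input value in, `result` reads it out.
case class Evaluation[B, I, R](init: B, update: (B, I) => B, result: B => R)

// COUNT over nullable ints, mirroring the If(IsNotNull(...), Add(...), ...) expression above.
val count = Evaluation[Long, Option[Int], Long](
  init = 0L,
  update = (c, v) => if (v.isDefined) c + 1 else c,
  result = identity)

val rows = Seq(Some(1), None, Some(3))
val buffer = rows.foldLeft(count.init)(count.update)
assert(count.result(buffer) == 2L) // nulls are not counted
```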
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HashGeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HashGeneratedAggregate.scala
new file mode 100644
index 000000000000..543bf2e545b1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HashGeneratedAggregate.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.TaskContext
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.types._
+
+/**
+ * :: DeveloperApi ::
+ * Alternate version of hash aggregation that leverages projection and thus code generation.
+ * Aggregations are converted into a set of projections from an aggregation buffer tuple back onto
+ * itself. Currently only simple aggregations like SUM, COUNT, or AVERAGE are supported.
+ *
+ * @param partial if true then aggregation is done partially on local data without shuffling to
+ *                ensure all values where `groupingExpressions` are equal are present.
+ * @param groupingExpressions expressions that are evaluated to determine grouping.
+ * @param aggregateExpressions expressions that are computed for each group.
+ * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
+ * @param child the input data source.
+ */
+@DeveloperApi
+case class HashGeneratedAggregate(
+    partial: Boolean,
+    groupingExpressions: Seq[Expression],
+    aggregateExpressions: Seq[NamedExpression],
+    unsafeEnabled: Boolean,
+    child: SparkPlan)
+  extends UnaryNode with GeneratedAggregate {
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    if (partial) {
+      UnspecifiedDistribution :: Nil
+    } else {
+      if (groupingExpressions == Nil) {
+        AllTuples :: Nil
+      } else {
+        ClusteredDistribution(groupingExpressions) :: Nil
+      }
+    }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+
+    val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
+
+    val groupKeySchema: StructType = {
+      val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
+        // This is a dummy field name
+        StructField(idx.toString, expr.dataType, expr.nullable)
+      }
+      StructType(fields)
+    }
+
+    val schemaSupportsUnsafe: Boolean = {
+      UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+        UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
+    }
+
+    child.execute().mapPartitions { iter =>
+      // Builds a new custom class for holding the results of aggregation for a group.
+      val initialValues = computeFunctions.flatMap(_.initialValues)
+      val newAggregationBuffer = newProjection(initialValues, child.output)
+      log.info(s"Initial values: ${initialValues.mkString(",")}")
+
+      // A projection that computes the group given an input tuple.
+      val groupProjection = newProjection(groupingExpressions, child.output)
+      log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
+
+      // A projection that is used to update the aggregate values for a group given a new tuple.
+      // This projection should be targeted at the current values for the group and then applied
+      // to a joined row of the current values with the new input row.
+      val updateExpressions = computeFunctions.flatMap(_.update)
+      val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
+      val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
+      log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
+
+      // A projection that produces the final result, given a computation.
+      val resultProjectionBuilder =
+        newMutableProjection(
+          resultExpressions,
+          namedGroups.map(_._2) ++ computationSchema)
+      log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
+
+      val joinedRow = new JoinedRow3
+
+      if (groupingExpressions.isEmpty) {
+        // TODO: Codegening anything other than the updateProjection is probably over kill.
+        val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+        var currentRow: InternalRow = null
+        updateProjection.target(buffer)
+
+        while (iter.hasNext) {
+          currentRow = iter.next()
+          updateProjection(joinedRow(buffer, currentRow))
+        }
+
+        val resultProjection = resultProjectionBuilder()
+        Iterator(resultProjection(buffer))
+      } else if (unsafeEnabled && schemaSupportsUnsafe) {
+        log.info("Using Unsafe-based aggregator")
+        val aggregationMap = new UnsafeFixedWidthAggregationMap(
+          newAggregationBuffer(EmptyRow),
+          aggregationBufferSchema,
+          groupKeySchema,
+          TaskContext.get.taskMemoryManager(),
+          1024 * 16, // initial capacity
+          false // disable tracking of performance metrics
+        )
+
+        while (iter.hasNext) {
+          val currentRow: InternalRow = iter.next()
+          val groupKey: InternalRow = groupProjection(currentRow)
+          val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
+          updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
+        }
+
+        new Iterator[InternalRow] {
+          private[this] val mapIterator = aggregationMap.iterator()
+          private[this] val resultProjection = resultProjectionBuilder()
+
+          def hasNext: Boolean = mapIterator.hasNext
+
+          def next(): InternalRow = {
+            val entry = mapIterator.next()
+            val result = resultProjection(joinedRow(entry.key, entry.value))
+            if (hasNext) {
+              result
+            } else {
+              // This is the last element in the iterator, so let's free the buffer. Before we do,
+              // though, we need to make a defensive copy of the result so that we don't return an
+              // object that might contain dangling pointers to the freed memory
+              val resultCopy = result.copy()
+              aggregationMap.free()
+              resultCopy
+            }
+          }
+        }
+      } else {
+        if (unsafeEnabled) {
+          log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
+        }
+        val buffers = new java.util.HashMap[InternalRow, MutableRow]()
+
+        var currentRow: InternalRow = null
+        while (iter.hasNext) {
+          currentRow = iter.next()
+          val currentGroup = groupProjection(currentRow)
+          var currentBuffer = buffers.get(currentGroup)
+          if (currentBuffer == null) {
+            currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+            buffers.put(currentGroup, currentBuffer)
+          }
+          // Target the projection at the current aggregation buffer and then project the updated
+          // values.
+          updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
+        }
+
+        new Iterator[InternalRow] {
+          private[this] val resultIterator = buffers.entrySet.iterator()
+          private[this] val resultProjection = resultProjectionBuilder()
+
+          def hasNext: Boolean = resultIterator.hasNext
+
+          def next(): InternalRow = {
+            val currentGroup = resultIterator.next()
+            resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
+          }
+        }
+      }
+    }
+  }
+}
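Both grouped branches of `doExecute` follow the same loop: look up (or create) a per-group buffer in a map, then run the update projection against it. Stripped of Catalyst and Unsafe details, the loop is roughly this (hypothetical helper, not part of the patch); note that every group's buffer stays resident until the input is exhausted, which is the memory cost the sort-merge variant avoids:

```scala
import scala.collection.mutable

// Minimal hash-aggregation loop: one buffer per distinct key, all of them
// held in memory until the iterator is drained.
def hashAggregate[K, V, B](
    rows: Iterator[(K, V)],
    newBuffer: () => B,
    update: (B, V) => B): Iterator[(K, B)] = {
  val buffers = mutable.HashMap.empty[K, B]
  for ((k, v) <- rows) {
    buffers(k) = update(buffers.getOrElseUpdate(k, newBuffer()), v)
  }
  buffers.iterator
}

// e.g. a per-key count:
// hashAggregate(Iterator("a" -> 1, "b" -> 2, "a" -> 3), () => 0L, (c: Long, _: Int) => c + 1)
```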
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortMergeAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortMergeAggregate.scala
new file mode 100644
index 000000000000..729e68310eff
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortMergeAggregate.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
+
+/**
+ * :: DeveloperApi ::
+ * Sort-merge aggregation: groups input data by `groupingExpressions` and computes the
+ * `aggregateExpressions` for each group.
+ *
+ * @param groupingExpressions expressions that are evaluated to determine grouping.
+ * @param aggregateExpressions expressions that are computed for each group.
+ * @param child the input data source.
+ */
+@DeveloperApi
+case class SortMergeAggregate(
+    groupingExpressions: Seq[Expression],
+    aggregateExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+
+  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(groupingExpressions) :: Nil
+
+  // this is to manually construct an ordering that can be used to compare keys
+  private val keyOrdering: RowOrdering = RowOrdering.forSchema(groupingExpressions.map(_.dataType))
+
+  override def outputOrdering: Seq[SortOrder] = requiredOrders(groupingExpressions)
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    requiredOrders(groupingExpressions) :: Nil
+
+  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
+    keys.map(SortOrder(_, Ascending))
+
+  /**
+   * An aggregate that needs to be computed for each row in a group.
+   *
+   * @param unbound Unbound version of this aggregate, used for result substitution.
+   * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer.
+   * @param resultAttribute An attribute used to refer to the result of this aggregate in the final
+   *                        output.
+   */
+  case class ComputedAggregate(
+      unbound: AggregateExpression,
+      aggregate: AggregateExpression,
+      resultAttribute: AttributeReference)
+
+  /** A list of aggregates that need to be computed for each group. */
+  private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
+    agg.collect {
+      case a: AggregateExpression =>
+        ComputedAggregate(
+          a,
+          BindReferences.bindReference(a, child.output),
+          AttributeReference(s"aggResult:$a", a.dataType, a.nullable)())
+    }
+  }.toArray
+
+  /** The schema of the result of all aggregate evaluations */
+  private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
+
+  /** Creates a new aggregate buffer for a group. */
+  private[this] def newAggregateBuffer(): Array[AggregateFunction] = {
+    val buffer = new Array[AggregateFunction](computedAggregates.length)
+    var i = 0
+    while (i < computedAggregates.length) {
+      buffer(i) = computedAggregates(i).aggregate.newInstance()
+      i += 1
+    }
+    buffer
+  }
+
+  /** Named attributes used to substitute grouping attributes into the final result. */
+  private[this] val namedGroups = groupingExpressions.map {
+    case ne: NamedExpression => ne -> ne.toAttribute
+    case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute
+  }
+
+  /**
+   * A map of substitutions that are used to insert the aggregate expressions and grouping
+   * expression into the final result expression.
+   */
+  private[this] val resultMap =
+    (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap
+
+  /**
+   * Substituted version of aggregateExpressions expressions which are used to compute final
+   * output rows given a group and the result of all aggregate computations.
+   */
+  private[this] val resultExpressions = aggregateExpressions.map { agg =>
+    agg.transform {
+      case e: Expression if resultMap.contains(e) => resultMap(e)
+    }
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitions { iter =>
+      new Iterator[InternalRow] {
+
+        private[this] var currentElement: InternalRow = _
+        private[this] var nextElement: InternalRow = _
+        private[this] var currentKey: InternalRow = _
+        private[this] var nextKey: InternalRow = _
+        private[this] val groupingProjection =
+          new InterpretedMutableProjection(groupingExpressions, child.output)
+        private[this] var currentBuffer: Array[AggregateFunction] = _
+        private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
+        private[this] val resultProjection =
+          new InterpretedMutableProjection(
+            resultExpressions, computedSchema ++ namedGroups.map(_._2))
+        private[this] val joinedRow = new JoinedRow4
+
+        initialize()
+
+        private def initialize() = {
+          if (iter.hasNext) {
+            currentElement = iter.next()
+            currentKey = groupingProjection(currentElement).copy()
+          } else {
+            currentElement = null
+          }
+        }
+
+        override final def hasNext: Boolean = {
+          if (currentElement != null) {
+            currentBuffer = newAggregateBuffer()
+            var i = 0
+            while (i < currentBuffer.length) {
+              currentBuffer(i).update(currentElement)
+              i += 1
+            }
+            var stop: Boolean = false
+            while (!stop) {
+              if (iter.hasNext) {
+                nextElement = iter.next()
+                nextKey = groupingProjection(nextElement).copy()
+                stop = keyOrdering.compare(currentKey, nextKey) != 0
+                if (!stop) {
+                  var i = 0
+                  while (i < currentBuffer.length) {
+                    currentBuffer(i).update(nextElement)
+                    i += 1
+                  }
+                }
+              } else {
+                nextElement = null
+                stop = true
+              }
+            }
+            true
+          } else {
+            false
+          }
+        }
+
+        override final def next(): InternalRow = {
+          val currentGroup = currentKey
+          currentElement = nextElement
+          currentKey = nextKey
+          var i = 0
+          while (i < currentBuffer.length) {
+            aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
+            i += 1
+          }
+          resultProjection(joinedRow(aggregateResults, currentGroup))
+        }
+      }
+    }
+
+  }
+
+}
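The iterator above is the heart of the operator: because the child is already sorted by the grouping key, only one group's buffer is ever live, and a group is emitted the moment the key changes. The same control flow in miniature, counting elements per key (illustrative sketch, not patch code):

```scala
// Assumes `sorted` is already ordered by key, as requiredChildOrdering
// guarantees for the real operator.
def mergeCounts[K](sorted: Iterator[(K, Int)]): Iterator[(K, Int)] =
  new Iterator[(K, Int)] {
    private val input = sorted.buffered
    override def hasNext: Boolean = input.hasNext
    override def next(): (K, Int) = {
      val key = input.head._1
      var count = 0
      // Consume rows while the key is unchanged -- the only buffer held
      // for this group is `count`.
      while (input.hasNext && input.head._1 == key) {
        input.next()
        count += 1
      }
      (key, count)
    }
  }

// mergeCounts(Iterator("a" -> 1, "a" -> 2, "b" -> 3)).toList == List("a" -> 2, "b" -> 1)
```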
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortMergeGeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortMergeGeneratedAggregate.scala
new file mode 100644
index 000000000000..656d47967e7c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortMergeGeneratedAggregate.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.TaskContext
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.trees._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.types._
+
+/**
+ * :: DeveloperApi ::
+ * Alternate version of sort-merge aggregation that leverages projection and thus code generation.
+ * Aggregations are converted into a set of projections from an aggregation buffer tuple back onto
+ * itself. Currently only simple aggregations like SUM, COUNT, or AVERAGE are supported.
+ *
+ * @param groupingExpressions expressions that are evaluated to determine grouping.
+ * @param aggregateExpressions expressions that are computed for each group.
+ * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
+ * @param child the input data source.
+ */
+@DeveloperApi
+case class SortMergeGeneratedAggregate(
+    groupingExpressions: Seq[Expression],
+    aggregateExpressions: Seq[NamedExpression],
+    unsafeEnabled: Boolean,
+    sortEnabled: Boolean,
+    child: SparkPlan)
+  extends UnaryNode with GeneratedAggregate {
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(groupingExpressions) :: Nil
+
+  // this is to manually construct an ordering that can be used to compare keys
+  private val keyOrdering: RowOrdering = RowOrdering.forSchema(groupingExpressions.map(_.dataType))
+
+  override def outputOrdering: Seq[SortOrder] = requiredOrders(groupingExpressions)
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    requiredOrders(groupingExpressions) :: Nil
+
+  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
+    keys.map(SortOrder(_, Ascending))
+
+  protected override def doExecute(): RDD[InternalRow] = {
+
+    child.execute().mapPartitions { iter =>
+
+      // Builds a new custom class for holding the results of aggregation for a group.
+      val initialValues = computeFunctions.flatMap(_.initialValues)
+      val newAggregationBuffer = newProjection(initialValues, child.output)
+      log.info(s"Initial values: ${initialValues.mkString(",")}")
+
+      // A projection that computes the group given an input tuple.
+      val groupProjection = newProjection(groupingExpressions, child.output)
+      log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
+
+      // A projection that is used to update the aggregate values for a group given a new tuple.
+      // This projection should be targeted at the current values for the group and then applied
+      // to a joined row of the current values with the new input row.
+      val updateExpressions = computeFunctions.flatMap(_.update)
+      val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
+      val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
+      log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
+
+      // A projection that produces the final result, given a computation.
+      val resultProjectionBuilder =
+        newMutableProjection(
+          resultExpressions,
+          namedGroups.map(_._2) ++ computationSchema)
+      log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
+
+      val joinedRow = new JoinedRow3
+
+      new Iterator[InternalRow] {
+
+        private[this] var currentElement: InternalRow = _
+        private[this] var nextElement: InternalRow = _
+        private[this] var currentKey: InternalRow = _
+        private[this] var nextKey: InternalRow = _
+        private[this] var currentBuffer: MutableRow = _
+        private[this] val resultProjection = resultProjectionBuilder()
+
+        initialize()
+
+        private def initialize() = {
+          if (iter.hasNext) {
+            currentElement = iter.next()
+            currentKey = groupProjection(currentElement)
+          } else {
+            currentElement = null
+          }
+        }
+
+        override final def hasNext: Boolean = {
+          if (currentElement != null) {
+            currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+            // Target the projection at the current aggregation buffer and then project the updated
+            // values.
+            updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentElement))
+            var stop: Boolean = false
+            while (!stop) {
+              if (iter.hasNext) {
+                nextElement = iter.next()
+                nextKey = groupProjection(nextElement)
+                stop = keyOrdering.compare(currentKey, nextKey) != 0
+                if (!stop) {
+                  updateProjection.target(currentBuffer)(joinedRow(currentBuffer, nextElement))
+                }
+              } else {
+                nextElement = null
+                stop = true
+              }
+            }
+            true
+          } else {
+            false
+          }
+        }
+
+        override final def next(): InternalRow = {
+          val currentGroup = currentKey
+          currentElement = nextElement
+          currentKey = nextKey
+          resultProjection(joinedRow(currentGroup, currentBuffer))
+        }
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 422992d019c7..d3ea2beee38b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -121,6 +121,31 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       // Aggregations that can be performed in two phases, before and after the shuffle.
+      // Cases where all aggregates can be sort-merged and codegened.
+      case PartialAggregation(
+          namedGroupingAttributes,
+          rewrittenAggregateExpressions,
+          groupingExpressions,
+          partialComputation,
+          child)
+        if canBeCodeGened(
+          allAggregates(partialComputation) ++
+            allAggregates(rewrittenAggregateExpressions)) &&
+          codegenEnabled &&
+          sqlContext.conf.sortMergeAggregateEnabled &&
+          namedGroupingAttributes != Nil =>
+        execution.SortMergeGeneratedAggregate(
+          namedGroupingAttributes,
+          rewrittenAggregateExpressions,
+          unsafeEnabled,
+          false,
+          execution.HashGeneratedAggregate(
+            partial = true,
+            groupingExpressions,
+            partialComputation,
+            unsafeEnabled,
+            planLater(child))) :: Nil
+
       // Cases where all aggregates can be codegened.
       case PartialAggregation(
           namedGroupingAttributes,
@@ -132,18 +157,36 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           allAggregates(partialComputation) ++
             allAggregates(rewrittenAggregateExpressions)) &&
           codegenEnabled =>
-        execution.GeneratedAggregate(
+        execution.HashGeneratedAggregate(
           partial = false,
           namedGroupingAttributes,
           rewrittenAggregateExpressions,
           unsafeEnabled,
-          execution.GeneratedAggregate(
+          execution.HashGeneratedAggregate(
             partial = true,
             groupingExpressions,
             partialComputation,
             unsafeEnabled,
             planLater(child))) :: Nil
+
+      // Cases where some aggregates cannot be codegened but can be sort-merged
+      case PartialAggregation(
+          namedGroupingAttributes,
+          rewrittenAggregateExpressions,
+          groupingExpressions,
+          partialComputation,
+          child)
+        if sqlContext.conf.sortMergeAggregateEnabled &&
+          namedGroupingAttributes != Nil =>
+        execution.SortMergeAggregate(
+          namedGroupingAttributes,
+          rewrittenAggregateExpressions,
+          execution.Aggregate(
+            partial = true,
+            groupingExpressions,
+            partialComputation,
+            planLater(child))) :: Nil
+
       // Cases where some aggregate can not be codegened
       case PartialAggregation(
           namedGroupingAttributes,
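In the first new case, the final sort-merge aggregate sits directly on top of a partial, map-side hash aggregate; because the outer operator requires `ClusteredDistribution` and an ascending ordering on the keys, the planner will place an exchange and a sort between the two during plan preparation. A quick, hypothetical way to confirm the pairing, using the same pattern as the new tests below (`records` is an assumed temp table):

```scala
// Sketch: with codegen and sortMergeAggregate both enabled, a grouped query
// should contain both halves of the two-phase aggregate.
val plan = sqlContext.sql("SELECT key, SUM(value) FROM records GROUP BY key")
  .queryExecution.executedPlan
assert(plan.collect { case a: SortMergeGeneratedAggregate => a }.nonEmpty) // final, post-shuffle
assert(plan.collect { case a: HashGeneratedAggregate => a }.nonEmpty)      // partial, map-side
```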
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a47cc30e92e2..83027357affc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
-import org.apache.spark.sql.execution.GeneratedAggregate
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.test.SQLTestUtils
@@ -192,7 +192,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     // First, check if we have GeneratedAggregate.
     var hasGeneratedAgg = false
     df.queryExecution.executedPlan.foreach {
-      case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
+      case generatedAgg: HashGeneratedAggregate => hasGeneratedAgg = true
       case _ =>
     }
     if (!hasGeneratedAgg) {
@@ -281,6 +281,161 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     }
   }
 
+  test("sortmerge aggregation with codegen") {
+    val originalValue = sqlContext.conf.codegenEnabled
+    val originalValue2 = sqlContext.conf.sortMergeAggregateEnabled
+    sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
+    sqlContext.setConf(SQLConf.SORTMERGE_AGGREGATE, "true")
+    // Prepare a table that we can group some rows.
+    sqlContext.table("testData")
+      .unionAll(sqlContext.table("testData"))
+      .unionAll(sqlContext.table("testData"))
+      .registerTempTable("testData3x")
+
+    def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
+      val df = sql(sqlText)
+      // First, check if we have SortMergeGeneratedAggregate.
+      var hasGeneratedAgg = false
+      df.queryExecution.executedPlan.foreach {
+        case generatedAgg: SortMergeGeneratedAggregate => hasGeneratedAgg = true
+        case _ =>
+      }
+      if (!hasGeneratedAgg) {
+        fail(
+          s"""
+             |Codegen is enabled, but query $sqlText does not have SortMergeGeneratedAggregate
+             |in the plan.
+             |${df.queryExecution.simpleString}
+           """.stripMargin)
+      }
+      // Then, check results.
+      checkAnswer(df, expectedResults)
+    }
+
+    try {
+      // COUNT
+      testCodeGen(
+        "SELECT key, count(value) FROM testData3x GROUP BY key",
+        (1 to 100).map(i => Row(i, 3)))
+      // COUNT DISTINCT ON int
+      testCodeGen(
+        "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
+        (1 to 100).map(i => Row(i.toString, 1)))
+      // SUM
+      testCodeGen(
+        "SELECT value, sum(key) FROM testData3x GROUP BY value",
+        (1 to 100).map(i => Row(i.toString, 3 * i)))
+      // AVERAGE
+      testCodeGen(
+        "SELECT value, avg(key) FROM testData3x GROUP BY value",
+        (1 to 100).map(i => Row(i.toString, i)))
+      // MAX
+      testCodeGen(
+        "SELECT value, max(key) FROM testData3x GROUP BY value",
+        (1 to 100).map(i => Row(i.toString, i)))
+      // MIN
+      testCodeGen(
+        "SELECT value, min(key) FROM testData3x GROUP BY value",
+        (1 to 100).map(i => Row(i.toString, i)))
+      // Some combinations.
+      testCodeGen(
+        """
+          |SELECT
+          |  value,
+          |  sum(key),
+          |  max(key),
+          |  min(key),
+          |  avg(key),
+          |  count(key),
+          |  count(distinct key)
+          |FROM testData3x
+          |GROUP BY value
+        """.stripMargin,
+        (1 to 100).map(i => Row(i.toString, i * 3, i, i, i, 3, 1)))
+    } finally {
+      sqlContext.dropTempTable("testData3x")
+      sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+      sqlContext.setConf(SQLConf.SORTMERGE_AGGREGATE, originalValue2.toString)
+    }
+  }
+
+  test("sortmerge aggregation without codegen") {
+    val originalValue = sqlContext.conf.codegenEnabled
+    val originalValue2 = sqlContext.conf.sortMergeAggregateEnabled
+    sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "false")
+    sqlContext.setConf(SQLConf.SORTMERGE_AGGREGATE, "true")
+    // Prepare a table that we can group some rows.
+    sqlContext.table("testData")
+      .unionAll(sqlContext.table("testData"))
+      .unionAll(sqlContext.table("testData"))
+      .registerTempTable("testData3x")
+
+    def testSortMergeAggregate(sqlText: String, expectedResults: Seq[Row]): Unit = {
+      val df = sql(sqlText)
+      // First, check if we have SortMergeAggregate.
+      var hasGeneratedAgg = false
+      df.queryExecution.executedPlan.foreach {
+        case generatedAgg: SortMergeAggregate => hasGeneratedAgg = true
+        case _ =>
+      }
+      if (!hasGeneratedAgg) {
+        fail(
+          s"""
+             |Codegen is disabled, but query $sqlText does not have SortMergeAggregate in the plan.
+             |${df.queryExecution.simpleString}
+           """.stripMargin)
+      }
+      // Then, check results.
+      checkAnswer(df, expectedResults)
+    }
+
+    try {
+      // COUNT
+      testSortMergeAggregate(
+        "SELECT key, count(value) FROM testData3x GROUP BY key",
+        (1 to 100).map(i => Row(i, 3)))
+      // COUNT DISTINCT ON int
+      testSortMergeAggregate(
+        "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
+        (1 to 100).map(i => Row(i.toString, 1)))
+      // SUM
+      testSortMergeAggregate(
+        "SELECT value, sum(key) FROM testData3x GROUP BY value",
+        (1 to 100).map(i => Row(i.toString, 3 * i)))
+      // AVERAGE
+      testSortMergeAggregate(
+        "SELECT value, avg(key) FROM testData3x GROUP BY value",
+        (1 to 100).map(i => Row(i.toString, i)))
+      // MAX
+      testSortMergeAggregate(
+        "SELECT value, max(key) FROM testData3x GROUP BY value",
+        (1 to 100).map(i => Row(i.toString, i)))
+      // MIN
+      testSortMergeAggregate(
+        "SELECT value, min(key) FROM testData3x GROUP BY value",
+        (1 to 100).map(i => Row(i.toString, i)))
+      // Some combinations.
+      testSortMergeAggregate(
+        """
+          |SELECT
+          |  value,
+          |  sum(key),
+          |  max(key),
+          |  min(key),
+          |  avg(key),
+          |  count(key),
+          |  count(distinct key)
+          |FROM testData3x
+          |GROUP BY value
+        """.stripMargin,
+        (1 to 100).map(i => Row(i.toString, i * 3, i, i, i, 3, 1)))
+    } finally {
+      sqlContext.dropTempTable("testData3x")
+      sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+      sqlContext.setConf(SQLConf.SORTMERGE_AGGREGATE, originalValue2.toString)
+    }
+  }
+
   test("Add Parser of SQL COALESCE()") {
     checkAnswer(
       sql("""SELECT COALESCE(1, 2)"""),
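A note on the expected rows used throughout the two tests above: `testData` holds the pairs (1, "1") through (100, "100"), so the two `unionAll`s triple every row; hence `count(value)` per key is 3, while `count(distinct key)` grouped by value is 1. Modeled in plain Scala:

```scala
// Tiny model of testData3x: three copies of each (key, value) pair.
val testData = (1 to 100).map(i => (i, i.toString))
val testData3x = testData ++ testData ++ testData

val countsPerKey = testData3x.groupBy(_._1).mapValues(_.size)
assert(countsPerKey.values.forall(_ == 3)) // count(value) per key == 3

val distinctKeysPerValue = testData3x.groupBy(_._2).mapValues(_.map(_._1).distinct.size)
assert(distinctKeysPerValue.values.forall(_ == 1)) // count(distinct key) per value == 1
```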