Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,22 @@ import scala.collection.JavaConverters._
trait SQLConf {

/** Number of partitions to use for shuffle operators. */
private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt
private[sql] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt

/**
* Hash aggregation will be turned off if the ratio between hash table size and input rows
* is bigger than this number. Set to 1 to make sure hash aggregation is never turned off.
* Hive setting: hive.map.aggr.hash.min.reduction
*/
private[sql] def partialAggMinReduction: Double =
get("spark.sql.partialAgg.min.reduction", "0.5").toDouble

/**
* Number of rows to process before checking for [[partialAggMinReduction]].
* Hive setting: hive.groupby.mapaggr.checkinterval
*/
private[sql] def partialAggCheckInterval: Int =
get("spark.sql.partialAgg.check.interval", "10000").toInt

@transient
private val settings = java.util.Collections.synchronizedMap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

package org.apache.spark.sql.execution

import java.util.HashMap
import java.util.{HashMap => JHashMap}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
Expand Down Expand Up @@ -77,18 +77,20 @@ case class Aggregate(
resultAttribute: AttributeReference)

/** A list of aggregates that need to be computed for each group. */
private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
agg.collect {
case a: AggregateExpression =>
ComputedAggregate(
a,
BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression],
AttributeReference(s"aggResult:$a", a.dataType, nullable = true)())
}
}.toArray
private[this] val computedAggregates: Array[ComputedAggregate] =
aggregateExpressions.flatMap { agg =>
agg.collect {
case a: AggregateExpression =>
ComputedAggregate(
a,
BindReferences.bindReference(a, childOutput),
AttributeReference(s"aggResult:$a", a.dataType, nullable = true)())
}
}.toArray

/** The schema of the result of all aggregate evaluations */
private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
private[this] val computedSchema: Array[AttributeReference] =
computedAggregates.map(_.resultAttribute)

/** Creates a new aggregate buffer for a group. */
private[this] def newAggregateBuffer(): Array[AggregateFunction] = {
Expand All @@ -102,7 +104,7 @@ case class Aggregate(
}

/** Named attributes used to substitute grouping attributes into the final result. */
private[this] val namedGroups = groupingExpressions.map {
private[this] val namedGroups: Seq[(Expression, Attribute)] = groupingExpressions.map {
case ne: NamedExpression => ne -> ne.toAttribute
case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute
}
Expand All @@ -111,21 +113,22 @@ case class Aggregate(
* A map of substitutions that are used to insert the aggregate expressions and grouping
* expression into the final result expression.
*/
private[this] val resultMap =
private[this] val resultMap: Map[Expression, Attribute] =
(computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap

/**
* Substituted version of aggregateExpressions expressions which are used to compute final
* output rows given a group and the result of all aggregate computations.
*/
private[this] val resultExpressions = aggregateExpressions.map { agg =>
private[this] val resultExpressions: Seq[Expression] = aggregateExpressions.map { agg =>
agg.transform {
case e: Expression if resultMap.contains(e) => resultMap(e)
}
}

override def execute() = attachTree(this, "execute") {
if (groupingExpressions.isEmpty) {
// No grouping key, i.e. the output will contain only one row.
child.execute().mapPartitions { iter =>
val buffer = newAggregateBuffer()
var currentRow: Row = null
Expand All @@ -149,12 +152,15 @@ case class Aggregate(
Iterator(resultProjection(aggregateResults))
}
} else {
// With grouping key.
child.execute().mapPartitions { iter =>
val hashTable = new HashMap[Row, Array[AggregateFunction]]
val hashTable = new JHashMap[Row, Array[AggregateFunction]]
val groupingProjection = new MutableProjection(groupingExpressions, childOutput)

var partialAggEnabled = true
var rowCount: Int = 0
var currentRow: Row = null
while (iter.hasNext) {
while (iter.hasNext && partialAggEnabled) {
currentRow = iter.next()
val currentGroup = groupingProjection(currentRow)
var currentBuffer = hashTable.get(currentGroup)
Expand All @@ -168,9 +174,22 @@ case class Aggregate(
currentBuffer(i).update(currentRow)
i += 1
}

// Disable partial hash-based aggregation if desired minimum reduction is
// not observed after initial interval.
rowCount += 1
if (rowCount == 100 & partial) {
val hashTableSize = hashTable.size
logger.info(s"#hash table=$hashTableSize #rows=$rowCount " +
s"reduction=${hashTableSize.toFloat/rowCount} minReduction=0.5")
if (hashTableSize > rowCount * 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Man, those are some high standards!

logger.info("Partial aggregation disabled")
partialAggEnabled = false
}
}
}

new Iterator[Row] {
val hashTableIter = new Iterator[Row] {
private[this] val hashTableIter = hashTable.entrySet().iterator()
private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
private[this] val resultProjection =
Expand All @@ -193,6 +212,27 @@ case class Aggregate(
}
resultProjection(joinedRow(aggregateResults, currentGroup))
}
} // end of hashTableIter

if (!partialAggEnabled && iter.hasNext) {
// Partial aggregation disabled and not all entries have been consumed by the hash table.
val aggregateResults = new GenericMutableRow(computedAggregates.length)
val resultProjection =
new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
val joinedRow = new JoinedRow
hashTableIter ++ iter.map { currentRow =>
val currentGroup = groupingProjection(currentRow)
val currentBuffer = newAggregateBuffer()
var i = 0
while (i < currentBuffer.length) {
currentBuffer(i).update(currentRow)
aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
i += 1
}
resultProjection(joinedRow(aggregateResults, currentGroup))
}
} else {
hashTableIter
}
}
}
Expand Down