sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala (7 additions, 0 deletions)
@@ -238,6 +238,13 @@ private[spark] object SQLConf {
doc = "When true, use the optimized Tungsten physical execution backend which explicitly " +
"manages memory and dynamically generates bytecode for expression evaluation.")

val PRE_AGGREGATION_ENABLED = booleanConf("spark.sql.preAggregation.enabled",
defaultValue = Some(true),
doc = "When true, use pre-aggregation in TungstenAggregationIterator to continue to " +
"perform hash-based pre-aggregation after we've decided to spill and switch to " +
"sort-based aggregation. If a very low reduction factor is expected for the data " +
"this feature shoule be disabled to obtain better performance.")

val CODEGEN_ENABLED = booleanConf("spark.sql.codegen",
defaultValue = Some(true), // use TUNGSTEN_ENABLED as default
doc = "When true, code will be dynamically generated at runtime for expression evaluation in" +
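The flag defaults to true and can be flipped per session before running a query. A minimal usage sketch, assuming a Spark 1.5-style SQLContext named sqlContext is in scope (the events table is hypothetical):

// Disable pre-aggregation for a query whose grouping keys are expected to be
// nearly unique (very low reduction factor), so hashing rows before the
// sort-based fallback would mostly waste memory and CPU.
sqlContext.setConf("spark.sql.preAggregation.enabled", "false")
sqlContext.sql("SELECT id, SUM(v) FROM events GROUP BY id").collect()

// Restore the default for subsequent queries.
sqlContext.setConf("spark.sql.preAggregation.enabled", "true")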
TungstenAggregate.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.types.StructType

case class TungstenAggregate(
@@ -78,6 +79,8 @@ case class TungstenAggregate(
}
}

private val preAggregation: Boolean = sqlContext.getConf(SQLConf.PRE_AGGREGATION_ENABLED)

protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numInputRows = longMetric("numInputRows")
val numOutputRows = longMetric("numOutputRows")
@@ -100,6 +103,7 @@ case class TungstenAggregate(
newMutableProjection,
child.output,
testFallbackStartsAt,
preAggregation,
numInputRows,
numOutputRows,
dataSize,
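A note on the wiring: the flag is read once via sqlContext.getConf while the TungstenAggregate operator is constructed on the driver, so the resulting boolean is captured in the task closure and executors never consult the configuration per row.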
TungstenAggregationIterator.scala
@@ -86,6 +86,7 @@ class TungstenAggregationIterator(
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
originalInputAttributes: Seq[Attribute],
testFallbackStartsAt: Option[Int],
preAggregation: Boolean,
numInputRows: LongSQLMetric,
numOutputRows: LongSQLMetric,
dataSize: LongSQLMetric,
@@ -473,28 +474,31 @@
// Part 3: Methods and fields used by hash-based aggregation.
///////////////////////////////////////////////////////////////////////////

private[this] def initHashMap(): UnsafeFixedWidthAggregationMap = {
new UnsafeFixedWidthAggregationMap(
initialAggregationBuffer,
StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)),
StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
TaskContext.get.taskMemoryManager(),
1024 * 16, // initial capacity
TaskContext.get().taskMemoryManager().pageSizeBytes,
false // disable tracking of performance metrics
)
}

// This is the hash map used for hash-based aggregation. It is backed by an
// UnsafeFixedWidthAggregationMap and it is used to store
// all groups and their corresponding aggregation buffers for hash-based aggregation.
private[this] var hashMap = initHashMap()

// Exposed for testing
private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap

// Process input rows using a given UnsafeFixedWidthAggregationMap.
// It returns None if all input rows are processed with the given hash map, and
// returns (groupingKey, current input) if we can't allocate more memory for the hash map.
private def internalProcessInputs(
hashMap: UnsafeFixedWidthAggregationMap): Option[(UnsafeRow, InternalRow)] = {
if (groupingExpressions.isEmpty) {
// If there is no grouping expressions, we can just reuse the same buffer over and over again.
// Note that it would be better to eliminate the hash map entirely in the future.
@@ -512,14 +516,25 @@ class TungstenAggregationIterator(
val groupingKey = groupProjection.apply(newInput)
val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
if (buffer == null) {
// buffer == null means that we could not allocate more memory.
// Return the current row so that the caller can spill the map and
// switch to sort-based aggregation.
return Some((groupingKey, newInput))
} else {
processRow(buffer, newInput)
}
}
}
None
}

// The function used to read and process input rows. When processing input rows,
// it first uses hash-based aggregation by putting groups and their buffers in
// hashMap. If we could not allocate more memory for the map, we switch to
// sort-based aggregation (by calling switchToSortBasedAggregation).
private def processInputs(): Unit = {
assert(inputIter != null, "attempted to process input when iterator was null")
val ret = internalProcessInputs(hashMap)
if (ret.isDefined) {
switchToSortBasedAggregation(ret.get._1, ret.get._2)
}
}
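
The contract of internalProcessInputs is easiest to see in isolation. Below is a self-contained sketch of the same control flow, substituting a plain mutable.Map with an artificial capacity for UnsafeFixedWidthAggregationMap; every name here is hypothetical and the aggregation itself is elided, since only the Option-based hand-back matters:

import scala.collection.mutable

// Fill `map` from `rows` until inserting a new key would exceed `capacity`.
// Returns None once the iterator is exhausted, or Some(leftover) carrying the
// first pair that did not fit, mirroring internalProcessInputs.
def fillMap[K, V](rows: Iterator[(K, V)], map: mutable.Map[K, V],
    capacity: Int): Option[(K, V)] = {
  while (rows.hasNext) {
    val (k, v) = rows.next()
    if (!map.contains(k) && map.size >= capacity) {
      return Some((k, v)) // "allocation failed": hand the row back to the caller
    }
    map(k) = v // stand-in for aggregating into the group's buffer
  }
  None
}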

// This function is only used for testing. It is basically the same as processInputs except
@@ -604,12 +619,36 @@

// Process the rest of input rows.
while (inputIter.hasNext) {
if (preAggregation) {
hashMap = initHashMap()
val ret = internalProcessInputs(hashMap)
if (ret.isDefined) {
// If we could not allocate more memory, we insert all records from the
// hash map into externalSorter.
val iter = hashMap.iterator()
while (iter.next()) {
externalSorter.insertKV(iter.getKey(), iter.getValue())
}
hashMap.free()

buffer.copyFrom(initialAggregationBuffer)
processRow(buffer, ret.get._2)
externalSorter.insertKV(ret.get._1, buffer)
} else {
val iter = hashMap.iterator()
while (iter.next()) {
externalSorter.insertKV(iter.getKey(), iter.getValue())
}
hashMap.free()
}
} else {
val newInput = inputIter.next()
numInputRows += 1
val groupingKey = groupProjection.apply(newInput)
buffer.copyFrom(initialAggregationBuffer)
processRow(buffer, newInput)
externalSorter.insertKV(groupingKey, buffer)
}
}
} else {
// When needsProcess is false, the format of input rows is groupingKey + aggregation buffer.
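The preAggregation branch above restarts hashing after every spill rather than degrading to row-at-a-time inserts. Continuing the earlier sketch (same hypothetical fillMap helper, with a Buffer standing in for the external sorter), the overall pattern is:

import scala.collection.mutable

// Repeatedly hash-aggregate the remaining input into a fresh bounded map;
// whenever the map fills up, flush its pairs into `spilled` (the analogue of
// externalSorter.insertKV) and start over with an empty map.
def preAggregateAll[K, V](rows: Iterator[(K, V)], capacity: Int,
    spilled: mutable.Buffer[(K, V)]): Unit = {
  while (rows.hasNext) {
    val map = mutable.Map.empty[K, V] // corresponds to hashMap = initHashMap()
    fillMap(rows, map, capacity) match {
      case Some(leftover) =>
        spilled ++= map // spill the full map, one (key, buffer) pair at a time
        spilled += leftover // the row that did not fit goes straight to the sorter
      case None =>
        spilled ++= map // input exhausted: flush the final, partially full map
    }
  }
}

When the reduction factor is high, most duplicate keys collapse inside each map before reaching the sorter; when it is close to 1, every map fills up without collapsing anything, which is exactly the case the new doc string says to disable the feature for.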
TungstenAggregationIteratorSuite.scala
@@ -39,7 +39,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext
}
val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
0, Seq.empty, newMutableProjection, Seq.empty, None, false,
dummyAccum, dummyAccum, dummyAccum, dummyAccum)
val numPages = iter.getHashMap.getNumDataPages
assert(numPages === 1)