diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 59fdf659c9e11..90594ac4de3d2 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -34,12 +34,37 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - private val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) + private[this] var externalSorting = true + private[this] var partialAggCheckInterval = 10000 + private[this] var partialAggReduction = 0.5 + + private[spark] def withConf(conf: SparkConf): this.type = { + externalSorting = conf.getBoolean("spark.shuffle.spill", defaultValue = true) + partialAggCheckInterval = conf.getInt("spark.partialAgg.interval", 10000) + partialAggReduction = conf.getDouble("spark.partialAgg.reduction", 0.5) + this + } + + // Load the configs from SparkEnv if SparkEnv is set (it wouldn't be set in unit tests). + if (SparkEnv.get != null) { + withConf(SparkEnv.get.conf) + } @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = combineValuesByKey(iter, null) + /** + * Combines values using a (potentially external) hash map and returns the combined results, aka + * partial aggregation. + * + * Note that the output of this function is not guaranteed to contain each key only once. It can + * choose to not combine values if it doesn't observe enough reduction in size with partial + * aggregation. In the default case, it will go through the first 10000 records and perform + * partial aggregation. After the first 10000 records, if it doesn't see a reduction factor + * smaller than 0.5, it will disable partial aggregation and simply output one row per input row + * for records beyond the first 10000. 
+ */ def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], context: TaskContext): Iterator[(K, C)] = { if (!externalSorting) { @@ -48,12 +73,33 @@ case class Aggregator[K, V, C] ( val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) } - while (iter.hasNext) { + + // A flag indicating whether we should do partial aggregation or not. + var partialAggEnabled = true + var numRecords = 0 + while (iter.hasNext && partialAggEnabled) { kv = iter.next() combiners.changeValue(kv._1, update) + + numRecords += 1 + if (numRecords == partialAggCheckInterval) { + // Disable partial aggregation if we don't see enough reduction. + val partialAggSize = combiners.size + if (partialAggSize > numRecords * partialAggReduction) { + partialAggEnabled = false + } + } + } + + if (!partialAggEnabled && iter.hasNext) { + // Partial aggregation was turned off because we didn't observe enough reduction. + combiners.iterator ++ iter.map { kv => (kv._1, createCombiner(kv._2)) } + } else { + // We consumed all our records in partial aggregation. Just iterate over the results. + combiners.iterator } - combiners.iterator } else { + // TODO: disable partial aggregation when reduction factor is not met. val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) while (iter.hasNext) { val (k, v) = iter.next() diff --git a/core/src/test/scala/org/apache/spark/AggregatorSuite.scala b/core/src/test/scala/org/apache/spark/AggregatorSuite.scala new file mode 100644 index 0000000000000..e91bde634d395 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/AggregatorSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +class AggregatorSuite extends FunSuite { + + private val testData = Seq(("k1", 1), ("k2", 1), ("k3", 1), ("k4", 1), ("k1", 1)) + + test("combineValuesByKey with partial aggregation") { + val agg = new Aggregator[String, Int, Int](v => v, (c, v) => c + v, _ + _).withConf( + new SparkConf().set("spark.shuffle.spill", "false")) + val output = agg.combineValuesByKey(testData.iterator, null).toMap + assert(output("k1") === 2) + assert(output("k2") === 1) + assert(output("k3") === 1) + assert(output("k4") === 1) + } + + test("combineValuesByKey disabling partial aggregation") { + val agg = new Aggregator[String, Int, Int](v => v, (c, v) => c + v, _ + _).withConf( + new SparkConf().set("spark.shuffle.spill", "false") + .set("spark.partialAgg.interval", "2") + .set("spark.partialAgg.reduction", "0.5")) + + val output = agg.combineValuesByKey(testData.iterator, null).toSeq + assert(output.count(record => record == ("k1", 1)) === 2) + assert(output.count(record => record == ("k2", 1)) === 1) + assert(output.count(record => record == ("k3", 1)) === 1) + assert(output.count(record => record == ("k4", 1)) === 1) + } + + test("partial aggregation check interval") { + val testDataWithPartial = Seq(("k1", 1), ("k1", 1), ("k2", 1)) + val testDataWithoutPartial = Seq(("k1", 1), ("k2", 1), ("k1", 1)) + + val agg = new Aggregator[String, 
Int, Int](v => v, (c, v) => c + v, _ + _).withConf( + new SparkConf().set("spark.shuffle.spill", "false") + .set("spark.partialAgg.interval", "2") + .set("spark.partialAgg.reduction", "0.5")) + + assert(agg.combineValuesByKey(testDataWithPartial.iterator, null).size === 2) + assert(agg.combineValuesByKey(testDataWithoutPartial.iterator, null).size === 3) + } + +}