diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 7f7921d56f49..e193ed222e22 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -278,4 +278,13 @@ package object config {
"spark.io.compression.codec.")
.booleanConf
.createWithDefault(false)
+
+ private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD =
+ ConfigBuilder("spark.shuffle.accurateBlockThreshold")
+ .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " +
+ "record the size accurately if it's above this config. This helps to prevent OOM by " +
+ "avoiding underestimating shuffle block size when fetch shuffle blocks.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(100 * 1024 * 1024)
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index b2e9a97129f0..048e0d018659 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -19,8 +19,13 @@ package org.apache.spark.scheduler
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
import org.roaringbitmap.RoaringBitmap
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.config
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
@@ -121,34 +126,41 @@ private[spark] class CompressedMapStatus(
}
/**
- * A [[MapStatus]] implementation that only stores the average size of non-empty blocks,
+ * A [[MapStatus]] implementation that stores the accurate size of huge blocks, which are larger
+ * than spark.shuffle.accurateBlockThreshold. It stores the average size of other non-empty blocks,
* plus a bitmap for tracking which blocks are empty.
*
* @param loc location where the task is being executed
* @param numNonEmptyBlocks the number of non-empty blocks
* @param emptyBlocks a bitmap tracking which blocks are empty
- * @param avgSize average size of the non-empty blocks
+ * @param avgSize average size of the non-empty and non-huge blocks
+ * @param hugeBlockSizes sizes of huge blocks by their reduceId.
*/
private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
private[this] var numNonEmptyBlocks: Int,
private[this] var emptyBlocks: RoaringBitmap,
- private[this] var avgSize: Long)
+ private[this] var avgSize: Long,
+ @transient private var hugeBlockSizes: Map[Int, Byte])
extends MapStatus with Externalizable {
// loc could be null when the default constructor is called during deserialization
- require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0,
+ require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1) // For deserialization only
+ protected def this() = this(null, -1, null, -1, null) // For deserialization only
override def location: BlockManagerId = loc
override def getSizeForBlock(reduceId: Int): Long = {
+ assert(hugeBlockSizes != null)
if (emptyBlocks.contains(reduceId)) {
0
} else {
- avgSize
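+ // Huge blocks return the size recorded for them individually; all other
+ // non-empty blocks fall back to the average size.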
+ hugeBlockSizes.get(reduceId) match {
+ case Some(size) => MapStatus.decompressSize(size)
+ case None => avgSize
+ }
}
}
@@ -156,6 +168,11 @@ private[spark] class HighlyCompressedMapStatus private (
loc.writeExternal(out)
emptyBlocks.writeExternal(out)
out.writeLong(avgSize)
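+ // Serialize the huge block map as a count followed by (reduceId, compressed size) pairs.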
+ out.writeInt(hugeBlockSizes.size)
+ hugeBlockSizes.foreach { kv =>
+ out.writeInt(kv._1)
+ out.writeByte(kv._2)
+ }
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -163,6 +180,14 @@ private[spark] class HighlyCompressedMapStatus private (
emptyBlocks = new RoaringBitmap()
emptyBlocks.readExternal(in)
avgSize = in.readLong()
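+ // Restore the huge block map from the count and (reduceId, compressed size) pairs
+ // written by writeExternal.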
+ val count = in.readInt()
+ val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]()
+ (0 until count).foreach { _ =>
+ val block = in.readInt()
+ val size = in.readByte()
+ hugeBlockSizesArray += Tuple2(block, size)
+ }
+ hugeBlockSizes = hugeBlockSizesArray.toMap
}
}
@@ -178,11 +203,21 @@ private[spark] object HighlyCompressedMapStatus {
// we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
val emptyBlocks = new RoaringBitmap()
val totalNumBlocks = uncompressedSizes.length
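+ // SparkEnv may not be initialized (e.g. in some tests), so fall back to the
+ // default threshold in that case.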
+ val threshold = Option(SparkEnv.get)
+ .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD))
+ .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get)
+ val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]()
while (i < totalNumBlocks) {
- var size = uncompressedSizes(i)
+ val size = uncompressedSizes(i)
if (size > 0) {
numNonEmptyBlocks += 1
- totalSize += size
+ // Huge blocks are not included in the average size calculation, so the average is a more
+ // accurate estimate for the smaller blocks.
+ if (size < threshold) {
+ totalSize += size
+ } else {
+ hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i)))
+ }
} else {
emptyBlocks.add(i)
}
@@ -195,6 +230,7 @@ private[spark] object HighlyCompressedMapStatus {
}
emptyBlocks.trim()
emptyBlocks.runOptimize()
- new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize)
+ new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
+ hugeBlockSizesArray.toMap)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index 759d52fca5ce..3ec37f674c77 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -17,11 +17,15 @@
package org.apache.spark.scheduler
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+
import scala.util.Random
+import org.mockito.Mockito._
import org.roaringbitmap.RoaringBitmap
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
+import org.apache.spark.internal.config
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.BlockManagerId
@@ -128,4 +132,26 @@ class MapStatusSuite extends SparkFunSuite {
assert(size1 === size2)
assert(!success)
}
+
+ test("Blocks which are bigger than SHUFFLE_ACCURATE_BLOCK_THRESHOLD should not be " +
+ "underestimated.") {
+ val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "1000")
+ val env = mock(classOf[SparkEnv])
+ doReturn(conf).when(env).conf
+ SparkEnv.set(env)
+ // The value of each element in sizes equals its index.
+ val sizes = (0L to 2000L).toArray
+ val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes)
+ val arrayStream = new ByteArrayOutputStream(102400)
+ val objectOutputStream = new ObjectOutputStream(arrayStream)
+ assert(status1.isInstanceOf[HighlyCompressedMapStatus])
+ objectOutputStream.writeObject(status1)
+ objectOutputStream.flush()
+ val array = arrayStream.toByteArray
+ val objectInput = new ObjectInputStream(new ByteArrayInputStream(array))
+ val status2 = objectInput.readObject().asInstanceOf[HighlyCompressedMapStatus]
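+ // Every block above the 1000-byte threshold must report at least its actual size.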
+ (1001 to 2000).foreach {
+ case part => assert(status2.getSizeForBlock(part) >= sizes(part))
+ }
+ }
}
diff --git a/docs/configuration.md b/docs/configuration.md
index 1d8d963016c7..a6b6d5dfa5f9 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -612,6 +612,15 @@ Apart from these, the following properties are also available, and may be useful
<code>spark.io.compression.codec</code>.
</td>
</tr>
<tr>
+ <td><code>spark.shuffle.accurateBlockThreshold</code></td>
+ <td>100 * 1024 * 1024</td>
+ <td>
+ When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will record the
+ size accurately if it's above this config. This helps to prevent OOM by avoiding
+ underestimating shuffle block size when fetching shuffle blocks.
+ </td>
+ </tr>
+ <tr>
<td><code>spark.io.encryption.enabled</code></td>
<td>false</td>