core/src/main/scala/org/apache/spark/MapOutputTracker.scala (9 additions, 21 deletions)
@@ -319,7 +319,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
// For testing
def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1, useOldFetchProtocol = false)
getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
}

/**
@@ -334,8 +334,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
def getMapSizesByExecutorId(
shuffleId: Int,
startPartition: Int,
endPartition: Int,
useOldFetchProtocol: Boolean)
endPartition: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

/**
@@ -688,15 +687,14 @@ private[spark] class MapOutputTrackerMaster(
def getMapSizesByExecutorId(
shuffleId: Int,
startPartition: Int,
endPartition: Int,
useOldFetchProtocol: Boolean)
endPartition: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some (shuffleStatus) =>
shuffleStatus.withMapStatuses { statuses =>
MapOutputTracker.convertMapStatuses(
shuffleId, startPartition, endPartition, statuses, useOldFetchProtocol)
shuffleId, startPartition, endPartition, statuses)
}
case None =>
Iterator.empty
@@ -733,14 +731,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
override def getMapSizesByExecutorId(
shuffleId: Int,
startPartition: Int,
endPartition: Int,
useOldFetchProtocol: Boolean)
endPartition: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
MapOutputTracker.convertMapStatuses(
shuffleId, startPartition, endPartition, statuses, useOldFetchProtocol)
shuffleId, startPartition, endPartition, statuses)
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
@@ -883,7 +880,6 @@ private[spark] object MapOutputTracker extends Logging {
* @param startPartition Start of map output partition ID range (included in range)
* @param endPartition End of map output partition ID range (excluded from range)
* @param statuses List of map statuses, indexed by map partition index.
* @param useOldFetchProtocol Whether to use the old shuffle fetch protocol.
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
* and the second item is a sequence of (shuffle block id, shuffle block size, map index)
* tuples describing the shuffle blocks that are stored at that block manager.
@@ -892,8 +888,7 @@
shuffleId: Int,
startPartition: Int,
endPartition: Int,
statuses: Array[MapStatus],
useOldFetchProtocol: Boolean): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
assert (statuses != null)
val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
for ((status, mapIndex) <- statuses.iterator.zipWithIndex) {
@@ -905,15 +900,8 @@
for (part <- startPartition until endPartition) {
val size = status.getSizeForBlock(part)
if (size != 0) {
if (useOldFetchProtocol) {
@xuanyuanking (Member) commented on Oct 13, 2019:

Thanks for the report and fix!
The root cause is that when we set useOldFetchProtocol=true here, the shuffle id on the reader side and the writer side are inconsistent.
But we can't fix it like this: when useOldFetchProtocol=true we use the old fetch protocol, OpenBlocks, which treats the map id as an Integer and parses the string directly. So for a big, long-running application it will still not work. See the code:

mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);

So I think the right fix is in ShuffleWriteProcessor: we should fill mapId with mapTaskId or mapIndex depending on the config spark.shuffle.useOldFetchProtocol.

Another member replied:

Sorry, could you explain why parsing the map id as an Integer does not work for a big, long-running application? Is it a performance problem?
Looking forward to your reply.
// While we use the old shuffle fetch protocol, we use mapIndex as mapId in the
// ShuffleBlockId.
splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
((ShuffleBlockId(shuffleId, mapIndex, part), size, mapIndex))
} else {
splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
((ShuffleBlockId(shuffleId, status.mapTaskId, part), size, mapIndex))
}
splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
((ShuffleBlockId(shuffleId, status.mapId, part), size, mapIndex))
}
}
}
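To illustrate the point raised in the review thread above (a sketch, not part of this PR): the legacy OpenBlocks message parses the map-id component of a shuffle block name with Integer.parseInt, so a map id derived from taskAttemptId stops fitting once attempt ids pass Int.MaxValue in a long-running application. The block-name layout below assumes the usual shuffle_<shuffleId>_<mapId>_<reduceId> encoding; all values are made up.

// Illustrative only: a long-running app can exceed Int.MaxValue task attempts.
val taskAttemptId: Long = Int.MaxValue.toLong + 10L          // 2147483657
// Block name as a writer would encode it if mapId were the task attempt id.
val blockName = s"shuffle_5_${taskAttemptId}_2"
val blockIdParts = blockName.split("_")
// The old OpenBlocks path does Integer.parseInt(blockIdParts[2]); the value is
// outside Int range, so parsing fails and the fetch breaks.
println(scala.util.Try(Integer.parseInt(blockIdParts(2))))
// Failure(java.lang.NumberFormatException: For input string: "2147483657")
// The new fetch protocol carries the map id as a Long, which does fit.
println(blockIdParts(2).toLong)                              // 2147483657

So, at least in this sketch, the failure the reviewer describes is a parse error rather than a performance problem.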
@@ -45,9 +45,10 @@ private[spark] sealed trait MapStatus {
def getSizeForBlock(reduceId: Int): Long

/**
* The unique ID of this shuffle map task, we use taskContext.taskAttemptId to fill this.
* The unique ID of this shuffle map task. If spark.shuffle.useOldFetchProtocol is enabled, the
* partitionId of the task is used; otherwise taskContext.taskAttemptId is used.
*/
def mapTaskId: Long
def mapId: Long
}


Expand Down Expand Up @@ -129,7 +130,7 @@ private[spark] class CompressedMapStatus(
MapStatus.decompressSize(compressedSizes(reduceId))
}

override def mapTaskId: Long = _mapTaskId
override def mapId: Long = _mapTaskId

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
@@ -189,7 +190,7 @@ private[spark] class HighlyCompressedMapStatus private (
}
}

override def mapTaskId: Long = _mapTaskId
override def mapId: Long = _mapTaskId

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
@@ -23,7 +23,7 @@ import java.util.Properties

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.rdd.RDD

/**
@@ -91,7 +91,12 @@ private[spark] class ShuffleMapTask(

val rdd = rddAndDep._1
val dep = rddAndDep._2
dep.shuffleWriterProcessor.write(rdd, dep, context, partition)
// While we use the old shuffle fetch protocol, we use partitionId as mapId in the
// ShuffleBlockId construction.
val mapId = if (SparkEnv.get.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
partitionId
} else context.taskAttemptId()
dep.shuffleWriterProcessor.write(rdd, dep, mapId, context, partition)
}

override def preferredLocations: Seq[TaskLocation] = preferredLocs
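A small worked example (hypothetical values, not part of this PR) of what the mapId choice above means for the resulting block names, assuming the ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) shape used elsewhere in this diff:

import org.apache.spark.storage.ShuffleBlockId

// Hypothetical task: partition 3 of shuffle 7, with a large task attempt id.
val partitionId = 3
val taskAttemptId = 2147483650L

// spark.shuffle.useOldFetchProtocol=true: mapId is the partition id, so the
// block name stays within Int range for the legacy OpenBlocks parser.
println(ShuffleBlockId(7, partitionId, 2).name)     // shuffle_7_3_2

// spark.shuffle.useOldFetchProtocol=false: mapId is the task attempt id, which
// only the new fetch protocol can transport.
println(ShuffleBlockId(7, taskAttemptId, 2).name)   // shuffle_7_2147483650_2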
@@ -47,8 +47,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
context,
blockManager.blockStoreClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition,
SparkEnv.get.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)),
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
A contributor commented:

This is the shuffle read side and we need to know the value of SHUFFLE_USE_OLD_FETCH_PROTOCOL. I think the bug is on the shuffle write side, which is fixed in this PR. Do we really need to change the shuffle read side?

@sandeep-katta (Contributor, Author) replied on Oct 14, 2019:

This is redundant code: since the ShuffleWriter writes the mapId based on the spark.shuffle.useOldFetchProtocol flag, MapStatus.mapTaskId always gives the mapId that was set by the ShuffleWriter.
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
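A minimal sketch (illustrative only, using the three-argument MapStatus factory that also appears in the tests below) of the invariant sandeep-katta describes above: whatever mapId the writer registered travels inside the MapStatus, so the reader can rebuild the exact block name without consulting the flag again.

import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}

// Pretend the writer ran with the old protocol enabled and therefore registered
// its partition id (3) as the map id.
val status = MapStatus(BlockManagerId("exec-1", "hostA", 7337), Array(100L, 200L), 3L)

// Reader side: the block id is rebuilt purely from the status, so it matches
// whatever naming the writer used, old protocol or new.
println(ShuffleBlockId(7, status.mapId, 1).name)    // shuffle_7_3_1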
@@ -44,14 +44,15 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging {
def write(
rdd: RDD[_],
dep: ShuffleDependency[_, _, _],
mapId: Long,
context: TaskContext,
partition: Partition): MapStatus = {
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](
dep.shuffleHandle,
context.taskAttemptId(),
mapId,
context,
createMetricsReporter(context))
writer.write(
@@ -317,7 +317,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(size10000, size0, size1000, size0), 6))
assert(tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0, 4, false).toSeq ===
assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq ===
Seq(
(BlockManagerId("a", "hostA", 1000),
Seq((ShuffleBlockId(10, 5, 1), size1000, 0),
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import org.scalatest.BeforeAndAfterAll

class ShuffleOldFetchProtocolSuite extends ShuffleSuite with BeforeAndAfterAll {

// This test suite should run all tests by setting spark.shuffle.useOldFetchProtocol=true.
override def beforeAll(): Unit = {
super.beforeAll()
conf.set("spark.shuffle.useOldFetchProtocol", "true")
}
}
@@ -507,14 +507,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(initialMapStatus1.count(_ != null) === 3)
assert(initialMapStatus1.map{_.location.executorId}.toSet ===
Set("exec-hostA1", "exec-hostA2", "exec-hostB"))
assert(initialMapStatus1.map{_.mapTaskId}.toSet === Set(5, 6, 7))
assert(initialMapStatus1.map{_.mapId}.toSet === Set(5, 6, 7))

val initialMapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses
// val initialMapStatus1 = mapOutputTracker.mapStatuses.get(0).get
assert(initialMapStatus2.count(_ != null) === 3)
assert(initialMapStatus2.map{_.location.executorId}.toSet ===
Set("exec-hostA1", "exec-hostA2", "exec-hostB"))
assert(initialMapStatus2.map{_.mapTaskId}.toSet === Set(8, 9, 10))
assert(initialMapStatus2.map{_.mapId}.toSet === Set(8, 9, 10))

// reduce stage fails with a fetch failure from one host
complete(taskSets(2), Seq(
@@ -92,7 +92,7 @@ class MapStatusSuite extends SparkFunSuite {
val status1 = compressAndDecompressMapStatus(status)
assert(status1.isInstanceOf[HighlyCompressedMapStatus])
assert(status1.location == loc)
assert(status1.mapTaskId == mapTaskAttemptId)
assert(status1.mapId == mapTaskAttemptId)
for (i <- 0 until 3000) {
val estimate = status1.getSizeForBlock(i)
if (sizes(i) > 0) {
@@ -103,7 +103,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// shuffle data to read.
val mapOutputTracker = mock(classOf[MapOutputTracker])
when(mapOutputTracker.getMapSizesByExecutorId(
shuffleId, reduceId, reduceId + 1, useOldFetchProtocol = false)).thenReturn {
shuffleId, reduceId, reduceId + 1)).thenReturn {
// Test a scenario where all data is local, to avoid creating a bunch of additional mocks
// for the code to read data over the network.
val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>