core/src/main/scala/org/apache/spark/SparkEnv.scala (8 changes: 4 additions & 4 deletions)
@@ -328,15 +328,15 @@ object SparkEnv extends Logging {
conf.get(BLOCK_MANAGER_PORT)
}

val blockTransferService =
new NettyBlockTransferService(conf, securityManager, bindAddress, advertiseAddress,
blockManagerPort, numUsableCores)

val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
BlockManagerMaster.DRIVER_ENDPOINT_NAME,
new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)),
conf, isDriver)

val blockTransferService =
new NettyBlockTransferService(conf, securityManager, bindAddress, advertiseAddress,
blockManagerPort, numUsableCores, blockManagerMaster.driverEndpoint)

// NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
core/src/main/scala/org/apache/spark/SparkException.scala (6 changes: 6 additions & 0 deletions)
@@ -37,3 +37,9 @@ private[spark] class SparkDriverExecutionException(cause: Throwable)
*/
private[spark] case class SparkUserAppException(exitCode: Int)
extends SparkException(s"User application exited with $exitCode")

/**
* Exception thrown when the remote executor being accessed is dead.
*/
private[spark] case class ExecutorDeadException(message: String)
extends SparkException(message)
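
A dedicated exception type lets a fetch path fail fast instead of spending its retry budget on a peer that can never answer. A minimal caller-side sketch (fetchWithRetry, fetchOnce, and maxRetries are illustrative names, not part of this patch):

    import java.io.IOException

    import org.apache.spark.ExecutorDeadException

    def fetchWithRetry(maxRetries: Int)(fetchOnce: () => Unit): Unit = {
      var attempts = 0
      var done = false
      while (!done) {
        try {
          fetchOnce()
          done = true
        } catch {
          case e: ExecutorDeadException =>
            throw e // the remote executor is gone; further retries cannot succeed
          case _: IOException if attempts < maxRetries =>
            attempts += 1 // transient network failure; try again
        }
      }
    }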
core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,16 +17,19 @@

package org.apache.spark.network.netty

import java.io.IOException
import java.nio.ByteBuffer
import java.util.{HashMap => JHashMap, Map => JMap}

import scala.collection.JavaConverters._
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag
import scala.util.{Success, Try}

import com.codahale.metrics.{Metric, MetricSet}

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.ExecutorDeadException
import org.apache.spark.internal.config
import org.apache.spark.network._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
@@ -36,8 +39,10 @@ import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockFetcher}
import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.storage.BlockManagerMessages.IsExecutorAlive
import org.apache.spark.util.Utils

/**
Expand All @@ -49,7 +54,8 @@ private[spark] class NettyBlockTransferService(
bindAddress: String,
override val hostName: String,
_port: Int,
numCores: Int)
numCores: Int,
driverEndPointRef: RpcEndpointRef = null)
extends BlockTransferService {

// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
@@ -112,8 +118,20 @@
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
val client = clientFactory.createClient(host, port)
new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
transportConf, tempFileManager).start()
try {
new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
transportConf, tempFileManager).start()
} catch {
case e: IOException =>
Try {
driverEndPointRef.askSync[Boolean](IsExecutorAlive(execId))
} match {
case Success(false) =>
throw new ExecutorDeadException(s"The remote executor (id: $execId), which" +
" maintains the block data to fetch, is dead.")
case _ => throw e
}
}
}
}

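The new catch block only reclassifies an IOException: the transfer service asks the driver's BlockManagerMasterEndpoint whether the target executor is still registered, and upgrades the error to ExecutorDeadException only on a definite "no". With the default driverEndPointRef = null, askSync throws a NullPointerException inside Try, which lands in the case _ branch and rethrows the original IOException, so call sites that pass no endpoint keep the old behavior. The probe in isolation, as a sketch (probeExecutor is an illustrative name; IsExecutorAlive and askSync come from the patch):

    import scala.util.Try

    import org.apache.spark.rpc.RpcEndpointRef
    import org.apache.spark.storage.BlockManagerMessages.IsExecutorAlive

    // Some(isAlive) when the driver answered; None when the probe itself failed,
    // including the default driverEndPointRef = null case.
    def probeExecutor(driverEndPointRef: RpcEndpointRef, execId: String): Option[Boolean] =
      Try(driverEndPointRef.askSync[Boolean](IsExecutorAlive(execId))).toOption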
core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -105,6 +105,9 @@ class BlockManagerMasterEndpoint(
case GetBlockStatus(blockId, askSlaves) =>
context.reply(blockStatus(blockId, askSlaves))

case IsExecutorAlive(executorId) =>
context.reply(blockManagerIdByExecutor.contains(executorId))

case GetMatchingBlockIds(filter, askSlaves) =>
context.reply(getMatchingBlockIds(filter, askSlaves))

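"Alive" here means the executor still has a BlockManager registered with the master: blockManagerIdByExecutor is the map the endpoint maintains as block managers register and are removed, so an executor that has been torn down answers false. Driver-side usage, as a sketch (master is assumed to be the BlockManagerMaster constructed in SparkEnv above):

    import org.apache.spark.storage.BlockManagerMessages.IsExecutorAlive

    // Ask the master endpoint whether executor "exec-7" still has a registered BlockManager.
    val alive: Boolean = master.driverEndpoint.askSync[Boolean](IsExecutorAlive("exec-7"))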
core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -123,4 +123,6 @@ private[spark] object BlockManagerMessages {
case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster

case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster

case class IsExecutorAlive(executorId: String) extends ToBlockManagerMaster
}
core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
@@ -17,13 +17,21 @@

package org.apache.spark.network.netty

import java.io.IOException

import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag
import scala.util.Random

import org.mockito.Mockito.mock
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest._

import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.{ExecutorDeadException, SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.client.{TransportClient, TransportClientFactory}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcTimeout}

class NettyBlockTransferServiceSuite
extends SparkFunSuite
@@ -77,6 +85,48 @@ class NettyBlockTransferServiceSuite
verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port)
}

test("SPARK-27637: test fetch block with executor dead") {
implicit val executionContext = ExecutionContext.global
val port = 17634 + Random.nextInt(10000)
logInfo("random port for test: " + port)

val driverEndpointRef = new RpcEndpointRef(new SparkConf()) {
override def address: RpcAddress = null
override def name: String = "test"
override def send(message: Any): Unit = {}
// This RpcEndpointRef always returns false, so the unit test can hit ExecutorDeadException.
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
Future { false.asInstanceOf[T] }
}
}

val clientFactory = mock(classOf[TransportClientFactory])
val client = mock(classOf[TransportClient])
// Make every RPC throw an IOException to simulate a failure while fetching a block.
when(client.sendRpc(any(), any())).thenAnswer(_ => { throw new IOException() })
var createClientCount = 0
when(clientFactory.createClient(any(), any())).thenAnswer(_ => {
createClientCount += 1
client
})

val listener = mock(classOf[BlockFetchingListener])
var hitExecutorDeadException = false
when(listener.onBlockFetchFailure(any(), any(classOf[ExecutorDeadException])))
.thenAnswer(_ => { hitExecutorDeadException = true })

service0 = createService(port, driverEndpointRef)
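// clientFactory is private, but Scala name-mangles it into a public JVM field,
// so it can be looked up under the mangled name and swapped for the mock.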
val clientFactoryField = service0.getClass.getField(
"org$apache$spark$network$netty$NettyBlockTransferService$$clientFactory")
clientFactoryField.setAccessible(true)
clientFactoryField.set(service0, clientFactory)

service0.fetchBlocks("localhost", port, "exec1",
Array("block1"), listener, mock(classOf[DownloadFileManager]))
assert(createClientCount === 1)
assert(hitExecutorDeadException)
}

private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
actualPort should be >= expectedPort
// avoid testing equality in case of simultaneous tests
@@ -85,13 +135,15 @@
actualPort should be <= (expectedPort + 100)
}

private def createService(port: Int): NettyBlockTransferService = {
private def createService(
port: Int,
rpcEndpointRef: RpcEndpointRef = null): NettyBlockTransferService = {
val conf = new SparkConf()
.set("spark.app.id", s"test-${getClass.getName}")
val securityManager = new SecurityManager(conf)
val blockDataManager = mock(classOf[BlockDataManager])
val service = new NettyBlockTransferService(conf, securityManager, "localhost", "localhost",
port, 1)
port, 1, rpcEndpointRef)
service.init(blockDataManager)
service
}