diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 0dfadb657b77..43771180bb31 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1724,15 +1724,23 @@ private[spark] class BlockManager( * lock on the block. */ private def removeBlockInternal(blockId: BlockId, tellMaster: Boolean): Unit = { + val blockStatus = if (tellMaster) { + val blockInfo = blockInfoManager.assertBlockIsLockedForWriting(blockId) + Some(getCurrentBlockStatus(blockId, blockInfo)) + } else None + // Removals are idempotent in disk store and memory store. At worst, we get a warning. val removedFromMemory = memoryStore.remove(blockId) val removedFromDisk = diskStore.remove(blockId) if (!removedFromMemory && !removedFromDisk) { logWarning(s"Block $blockId could not be removed as it was not found on disk or in memory") } + blockInfoManager.removeBlock(blockId) if (tellMaster) { - reportBlockStatus(blockId, BlockStatus.empty) + // Only update storage level from the captured block status before deleting, so that + // memory size and disk size are being kept for calculating delta. + reportBlockStatus(blockId, blockStatus.get.copy(storageLevel = StorageLevel.NONE)) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 509d4efcab67..07c23cfb0e8e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,8 +28,8 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.commons.lang3.RandomUtils -import org.mockito.{ArgumentMatchers => mc} -import org.mockito.Mockito.{doAnswer, mock, spy, times, verify, when} +import org.mockito.{ArgumentCaptor, ArgumentMatchers => mc} +import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.scalatest._ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} @@ -143,9 +143,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // need to create a SparkContext is to initialize LiveListenerBus. sc = mock(classOf[SparkContext]) when(sc.conf).thenReturn(conf) - master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", - new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(conf), None)), conf, true) + master = spy(new BlockManagerMaster( + rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, + new LiveListenerBus(conf), None)), conf, true)) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -289,14 +290,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE eventually(timeout(1.second), interval(10.milliseconds)) { assert(!store.hasLocalBlock("a1-to-remove")) master.getLocations("a1-to-remove") should have size 0 + assertUpdateBlockInfoReportedForRemovingBlock(store, "a1-to-remove", + removedFromMemory = true, removedFromDisk = false) } eventually(timeout(1.second), interval(10.milliseconds)) { assert(!store.hasLocalBlock("a2-to-remove")) master.getLocations("a2-to-remove") should have size 0 + assertUpdateBlockInfoReportedForRemovingBlock(store, "a2-to-remove", + removedFromMemory = true, removedFromDisk = false) } eventually(timeout(1.second), interval(10.milliseconds)) { assert(store.hasLocalBlock("a3-to-remove")) master.getLocations("a3-to-remove") should have size 0 + assertUpdateBlockInfoNotReported(store, "a3-to-remove") } eventually(timeout(1.second), interval(10.milliseconds)) { val memStatus = master.getMemoryStatus.head._2 @@ -375,16 +381,21 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!executorStore.hasLocalBlock(broadcast0BlockId)) assert(executorStore.hasLocalBlock(broadcast1BlockId)) assert(executorStore.hasLocalBlock(broadcast2BlockId)) + assertUpdateBlockInfoReportedForRemovingBlock(executorStore, broadcast0BlockId, + removedFromMemory = false, removedFromDisk = true) // nothing should be removed from the driver store assert(driverStore.hasLocalBlock(broadcast0BlockId)) assert(driverStore.hasLocalBlock(broadcast1BlockId)) assert(driverStore.hasLocalBlock(broadcast2BlockId)) + assertUpdateBlockInfoNotReported(driverStore, broadcast0BlockId) // remove broadcast 0 block from the driver as well master.removeBroadcast(0, removeFromMaster = true, blocking = true) assert(!driverStore.hasLocalBlock(broadcast0BlockId)) assert(driverStore.hasLocalBlock(broadcast1BlockId)) + assertUpdateBlockInfoReportedForRemovingBlock(driverStore, broadcast0BlockId, + removedFromMemory = false, removedFromDisk = true) // remove broadcast 1 block from both the stores asynchronously // and verify all broadcast 1 blocks have been removed @@ -392,6 +403,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE eventually(timeout(1.second), interval(10.milliseconds)) { assert(!driverStore.hasLocalBlock(broadcast1BlockId)) assert(!executorStore.hasLocalBlock(broadcast1BlockId)) + assertUpdateBlockInfoReportedForRemovingBlock(driverStore, broadcast1BlockId, + removedFromMemory = false, removedFromDisk = true) + assertUpdateBlockInfoReportedForRemovingBlock(executorStore, broadcast1BlockId, + removedFromMemory = false, removedFromDisk = true) } // remove broadcast 2 from both the stores asynchronously @@ -402,11 +417,46 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!driverStore.hasLocalBlock(broadcast2BlockId2)) assert(!executorStore.hasLocalBlock(broadcast2BlockId)) assert(!executorStore.hasLocalBlock(broadcast2BlockId2)) + assertUpdateBlockInfoReportedForRemovingBlock(driverStore, broadcast2BlockId, + removedFromMemory = false, removedFromDisk = true) + assertUpdateBlockInfoReportedForRemovingBlock(driverStore, broadcast2BlockId2, + removedFromMemory = false, removedFromDisk = true) + assertUpdateBlockInfoReportedForRemovingBlock(executorStore, broadcast2BlockId, + removedFromMemory = false, removedFromDisk = true) + assertUpdateBlockInfoReportedForRemovingBlock(executorStore, broadcast2BlockId2, + removedFromMemory = false, removedFromDisk = true) } executorStore.stop() driverStore.stop() } + private def assertUpdateBlockInfoReportedForRemovingBlock( + store: BlockManager, + blockId: BlockId, + removedFromMemory: Boolean, + removedFromDisk: Boolean): Unit = { + def assertSizeReported(captor: ArgumentCaptor[Long], expectRemoved: Boolean): Unit = { + assert(captor.getAllValues().size() === 1) + if (expectRemoved) { + assert(captor.getValue() > 0) + } else { + assert(captor.getValue() === 0) + } + } + + val memSizeCaptor = ArgumentCaptor.forClass(classOf[Long]).asInstanceOf[ArgumentCaptor[Long]] + val diskSizeCaptor = ArgumentCaptor.forClass(classOf[Long]).asInstanceOf[ArgumentCaptor[Long]] + verify(master).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), + mc.eq(StorageLevel.NONE), memSizeCaptor.capture(), diskSizeCaptor.capture()) + assertSizeReported(memSizeCaptor, removedFromMemory) + assertSizeReported(diskSizeCaptor, removedFromDisk) + } + + private def assertUpdateBlockInfoNotReported(store: BlockManager, blockId: BlockId): Unit = { + verify(master, never()).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId), + mc.eq(StorageLevel.NONE), mc.anyInt(), mc.anyInt()) + } + test("reregistration on heart beat") { val store = makeBlockManager(2000) val a1 = new Array[Byte](400)