diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 700e2cb3f91b..3ed5f6d72160 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -70,6 +70,7 @@ class SparkEnv (
     val outputCommitCoordinator: OutputCommitCoordinator,
     val conf: SparkConf) extends Logging {
 
+  var currentStage: Int = -1
   private[spark] var isStopped = false
   private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
diff --git a/core/src/main/scala/org/apache/spark/StageExInfo.scala b/core/src/main/scala/org/apache/spark/StageExInfo.scala
new file mode 100644
index 000000000000..6eea5a5da9a7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/StageExInfo.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable
+
+/**
+ * Data structure that stores the execution info of a stage.
+ */
+class StageExInfo(val stageId: Int,
+    val alreadyPerRddSet: Set[Int], // prs
+    val afterPerRddSet: Set[Int], // aprs
+    val depMap: mutable.HashMap[Int, Set[Int]],
+    val curRunningRddMap: mutable.HashMap[Int, Set[Int]]) {
+
+}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 81e41e6fa715..b65b4b599d15 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -92,6 +92,12 @@ private[spark] class CoarseGrainedExecutorBackend(
     } else {
       val taskDesc = ser.deserialize[TaskDescription](data.value)
       logInfo("Got assigned task " + taskDesc.taskId)
+      val currentStageId = taskDesc.name.substring(taskDesc.name.lastIndexOf(' ') + 1,
+        taskDesc.name.lastIndexOf('.')).toInt
+      env.currentStage = currentStageId
+      env.blockManager.currentStage = currentStageId
+      // logEarne("this Stage has ExInfo: " + env.stageExInfos(currentStageId))
+
       executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
         taskDesc.name, taskDesc.serializedTask)
     }
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 09c57335650c..41bba82ebce5 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -220,7 +220,10 @@ private[spark] class Executor(
           // for the task.
          throw new TaskKilledException
        }
-
+        if (!env.blockManager.stageExInfos.contains(task.stageId)) {
+          env.blockManager.stageExInfos.put(task.stageId,
+            new StageExInfo(task.stageId, null, null, task.depMap, task.curRunningRddMap))
+        }
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 4a0a2199ef7e..1f3d2d4b44eb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -38,7 +38,7 @@ import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
-import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+import org.apache.spark.storage.{BlockExInfo, RDDBlockId, StorageLevel}
 import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
@@ -208,6 +208,7 @@ abstract class RDD[T: ClassTag](
    */
   def unpersist(blocking: Boolean = true): this.type = {
     logInfo("Removing RDD " + id + " from persistence list")
+    sc.dagScheduler.renewDepMap(id)
     sc.unpersistRDD(id, blocking)
     storageLevel = StorageLevel.NONE
     this
@@ -307,6 +308,31 @@ abstract class RDD[T: ClassTag](
     ancestors.filterNot(_ == this).toSeq
   }
 
+  /**
+   * Return the IDs of the cached ancestors that are reachable through narrow dependencies.
+   */
+  private[spark] def getNarrowCachedAncestors: Set[Int] = {
+    val cachedAncestors = new mutable.HashSet[Int]
+    val ancestors = new mutable.HashSet[RDD[_]]
+    def visit(rdd: RDD[_]): Unit = {
+      val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]])
+      val narrowParents = narrowDependencies.map(_.rdd)
+      val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains)
+      narrowParentsNotVisited.foreach { parent =>
+        ancestors.add(parent)
+        if (parent.getStorageLevel != StorageLevel.NONE) {
+          cachedAncestors.add(parent.id)
+        } else {
+          visit(parent)
+        }
+      }
+    }
+
+    visit(this)
+
+    cachedAncestors.filterNot(_ == this.id).toSet
+  }
+
   /**
    * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
    */
@@ -328,6 +354,39 @@ abstract class RDD[T: ClassTag](
       // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
       SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
         readCachedBlock = false
+        val key = blockId
+        logInfo(s"Partition $key not found, computing it")
+
+        val blockManager = SparkEnv.get.blockManager
+
+        if (!blockManager.blockExInfo.containsKey(key)) {
+          blockManager.blockExInfo.put(key, new BlockExInfo(key))
+        }
+
+        blockManager.stageExInfos.get(blockManager.currentStage) match {
+          case Some(curStageExInfo) =>
+            var parExist = true
+            for (par <- curStageExInfo.depMap(id)) {
+              val parBlockId = new RDDBlockId(par, partition.index)
+              if (blockManager.blockExInfo.containsKey(parBlockId) &&
+                blockManager.blockExInfo.get(parBlockId).isExist
+                == 1) { // the parent block already exists
+
+              } else { // the parent block is not present yet, add this key to its parent's watching set
+                parExist = false
+                if (!blockManager.blockExInfo.containsKey(parBlockId)) {
+                  blockManager.blockExInfo.put(parBlockId, new BlockExInfo(parBlockId))
+                }
+                blockManager.blockExInfo.get(parBlockId).sonSet += key
+              }
+            }
+            if (parExist) { // all parents exist, so record this RDD block's start time
+              logTrace("all parents exist, store start time of " + key)
+              blockManager.blockExInfo.get(key).creatStartTime = System.currentTimeMillis()
+            }
+          case None =>
+            logError("No StageExInfo found for stage " + blockManager.currentStage)
+        }
         computeOrReadCheckpoint(partition, context)
       }) match {
       case Left(blockResult) =>
@@ -483,8 +542,7 @@ abstract class RDD[T: ClassTag](
    *
    * @param weights weights for splits, will be normalized if they don't sum to 1
    * @param seed random seed
-   *
-   * @return split RDDs in an array
+   * @return split RDDs in an array
    */
   def randomSplit(
       weights: Array[Double],
@@ -499,7 +557,8 @@ abstract class RDD[T: ClassTag](
   /**
    * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
    * range.
-   * @param lb lower bound to use for the Bernoulli sampler
+   *
+   * @param lb lower bound to use for the Bernoulli sampler
    * @param ub upper bound to use for the Bernoulli sampler
    * @param seed the seed for the Random number generator
    * @return A random sub-sample of the RDD without replacement.
@@ -517,8 +576,7 @@ abstract class RDD[T: ClassTag](
    *
    * @note this method should only be used if the resulting array is expected to be small, as
    * all the data is loaded into the driver's memory.
-   *
-   * @param withReplacement whether sampling is done with replacement
+   * @param withReplacement whether sampling is done with replacement
    * @param num size of the returned sample
    * @param seed seed for the random number generator
    * @return sample of specified size in an array
@@ -1244,8 +1302,7 @@ abstract class RDD[T: ClassTag](
    *
    * @note this method should only be used if the resulting array is expected to be small, as
    * all the data is loaded into the driver's memory.
-   *
-   * @note due to complications in the internal implementation, this method will raise
+   * @note due to complications in the internal implementation, this method will raise
    * an exception if called on an RDD of `Nothing` or `Null`.
    */
   def take(num: Int): Array[T] = withScope {
@@ -1308,8 +1365,7 @@ abstract class RDD[T: ClassTag](
    *
    * @note this method should only be used if the resulting array is expected to be small, as
    * all the data is loaded into the driver's memory.
-   *
-   * @param num k, the number of top elements to return
+   * @param num k, the number of top elements to return
    * @param ord the implicit ordering for T
    * @return an array of top elements
    */
@@ -1331,8 +1387,7 @@ abstract class RDD[T: ClassTag](
    *
    * @note this method should only be used if the resulting array is expected to be small, as
    * all the data is loaded into the driver's memory.
-   *
-   * @param num k, the number of elements to return
+   * @param num k, the number of elements to return
    * @param ord the implicit ordering for T
    * @return an array of top elements
    */
@@ -1359,7 +1414,8 @@ abstract class RDD[T: ClassTag](
 
   /**
    * Returns the max of this RDD as defined by the implicit Ordering[T].
-   * @return the maximum element of the RDD
+   *
+   * @return the maximum element of the RDD
    * */
   def max()(implicit ord: Ordering[T]): T = withScope {
     this.reduce(ord.max)
@@ -1367,7 +1423,8 @@ abstract class RDD[T: ClassTag](
 
   /**
    * Returns the min of this RDD as defined by the implicit Ordering[T].
-   * @return the minimum element of the RDD
+   *
+   * @return the minimum element of the RDD
    * */
   def min()(implicit ord: Ordering[T]): T = withScope {
     this.reduce(ord.min)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5cdc91316b69..d4d123c4d354 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -155,6 +155,11 @@ class DAGScheduler(
 
   private[scheduler] val activeJobs = new HashSet[ActiveJob]
 
+  private[scheduler] var preRDDs = new HashSet[RDD[_]]
+
+  private[scheduler] var depMap = new HashMap[Int, Set[Int]]
+
+  private[scheduler] var curRunningRddMap = new HashMap[Int, Set[Int]]
   /**
    * Contains the locations that each RDD's partitions are cached on.  This map's keys are RDD ids
    * and its values are arrays indexed by partition numbers. Each array value is the set of
@@ -554,11 +559,9 @@ class DAGScheduler(
    * @param callSite where in the user program this job was called
    * @param resultHandler callback to pass each result to
    * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
-   *
-   * @return a JobWaiter object that can be used to block until the job finishes executing
+   * @return a JobWaiter object that can be used to block until the job finishes executing
    *         or can be used to cancel the job.
-   *
-   * @throws IllegalArgumentException when partitions ids are illegal
+   * @throws IllegalArgumentException when partitions ids are illegal
    */
   def submitJob[T, U](
       rdd: RDD[T],
@@ -601,8 +604,7 @@ class DAGScheduler(
    * @param callSite where in the user program this job was called
    * @param resultHandler callback to pass each result to
   * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
-   *
-   * @throws Exception when the job fails
+   * @throws Exception when the job fails
    */
   def runJob[T, U](
       rdd: RDD[T],
@@ -928,6 +930,16 @@ class DAGScheduler(
         logDebug("missing: " + missing)
         if (missing.isEmpty) {
           logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
+
+          val curRDDs = stage.rdd.getNarrowAncestors ++ Seq(stage.rdd)
+          val newRDDs = curRDDs.filter(!preRDDs.contains(_))
+          val newCachedRDDs = newRDDs.filter(_.getStorageLevel != StorageLevel.NONE)
+          curRunningRddMap.clear()
+          newCachedRDDs.foreach { cachedRdd =>
+            depMap.put(cachedRdd.id, cachedRdd.getNarrowCachedAncestors)
+            curRunningRddMap.put(cachedRdd.id, cachedRdd.getNarrowCachedAncestors)
+          }
+          preRDDs = preRDDs ++ curRDDs
           submitMissingTasks(stage, jobId.get)
         } else {
           for (parent <- missing) {
@@ -941,6 +953,22 @@ class DAGScheduler(
     }
   }
 
+  /** Renew the depMap when an RDD is unpersisted. */
+  def renewDepMap(id: Int): Unit = {
+    if (depMap.contains(id)) {
+      logTrace("Remove RDD " + id + " from depMap")
+      val value = depMap(id)
+      depMap.foreach { rdd =>
+        if (rdd._2.contains(id)) {
+          val tmp = rdd._2 - id
+          depMap.put(rdd._1, tmp ++ value)
+        }
+      }
+      depMap.remove(id)
+      logTrace("After removing RDD " + id + " the depMap is " + depMap)
+    }
+  }
+
   /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
@@ -1036,7 +1064,7 @@ class DAGScheduler(
             val locs = taskIdToLocations(id)
             val part = stage.rdd.partitions(id)
             new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
-              taskBinary, part, locs, stage.internalAccumulators)
+              taskBinary, part, locs, stage.internalAccumulators, depMap, curRunningRddMap)
           }
 
         case stage: ResultStage =>
@@ -1046,7 +1074,7 @@ class DAGScheduler(
             val part = stage.rdd.partitions(p)
             val locs = taskIdToLocations(id)
             new ResultTask(stage.id, stage.latestInfo.attemptId,
-              taskBinary, part, locs, id, stage.internalAccumulators)
+              taskBinary, part, locs, id, stage.internalAccumulators, depMap, curRunningRddMap)
           }
       }
     } catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index cd2736e1960c..b5e4e0aa93f7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -20,6 +20,8 @@ package org.apache.spark.scheduler
 import java.io._
 import java.nio.ByteBuffer
 
+import scala.collection.mutable.HashMap
+
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -49,9 +51,11 @@ private[spark] class ResultTask[T, U](
     partition: Partition,
     locs: Seq[TaskLocation],
     val outputId: Int,
-    _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll())
-  extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums)
-  with Serializable {
+    _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll(),
+    depMap: HashMap[Int, Set[Int]] = null,
+    curRunningRddMap: HashMap[Int, Set[Int]] = null)
+  extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, depMap,
+    curRunningRddMap) with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index e30964a01bda..2925599642e5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
 
+import scala.collection.mutable.HashMap
 import scala.language.existentials
 
 import org.apache.spark._
@@ -49,13 +50,16 @@ private[spark] class ShuffleMapTask(
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation],
-    _initialAccums: Seq[Accumulator[_]])
-  extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums)
+    _initialAccums: Seq[Accumulator[_]],
+    depMap: HashMap[Int, Set[Int]],
+    curRunningRddMap: HashMap[Int, Set[Int]])
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, depMap,
+    curRunningRddMap)
   with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) {
-    this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
+    this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, null, null)
   }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 46c64f61de5f..595a4a664205 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -51,7 +51,10 @@ private[spark] abstract class Task[T](
     val stageId: Int,
     val stageAttemptId: Int,
     val partitionId: Int,
-    val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
+    val initialAccumulators: Seq[Accumulator[_]],
+    var depMap: HashMap[Int, Set[Int]] = new HashMap[Int, Set[Int]],
+    var curRunningRddMap: HashMap[Int, Set[Int]] =
+      new HashMap[Int, Set[Int]]) extends Serializable {
 
   /**
    * Called by [[org.apache.spark.executor.Executor]] to run this task.
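The scheduler-side changes above thread two per-stage maps through the task path: when a stage is submitted, DAGScheduler records, for every newly cached RDD, the IDs of its cached narrow ancestors in depMap and curRunningRddMap; both maps ride along in ShuffleMapTask/ResultTask; and Executor.run registers them once per stage as a StageExInfo keyed by stage id. The standalone sketch below models only that hand-off with plain Scala collections; the names StageInfoModel, buildForStage and register are illustrative stand-ins and are not part of the patch.

import scala.collection.mutable

// Simplified stand-in for StageExInfo: which cached RDDs a stage produces and
// which cached narrow ancestors each of them depends on.
case class StageInfoModel(stageId: Int, depMap: Map[Int, Set[Int]], running: Map[Int, Set[Int]])

object DepMapHandOff {
  // Driver side: given the cached RDDs of a stage and their cached narrow ancestors,
  // build the maps that would be shipped with every task of that stage.
  def buildForStage(stageId: Int, cachedAncestorsByRdd: Map[Int, Set[Int]]): StageInfoModel =
    StageInfoModel(stageId, cachedAncestorsByRdd, cachedAncestorsByRdd)

  // Executor side: register the stage info once per stage, the way Executor.run
  // guards blockManager.stageExInfos with a contains check.
  def register(store: mutable.HashMap[Int, StageInfoModel], info: StageInfoModel): Unit =
    if (!store.contains(info.stageId)) store.put(info.stageId, info)

  def main(args: Array[String]): Unit = {
    // RDD 7 is cached and was built from cached RDD 3; RDD 3 has no cached ancestors.
    val stage = buildForStage(stageId = 1, Map(7 -> Set(3), 3 -> Set.empty[Int]))
    val stageExInfos = mutable.HashMap.empty[Int, StageInfoModel]
    register(stageExInfos, stage)
    println(stageExInfos(1).depMap) // Map(7 -> Set(3), 3 -> Set())
  }
}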
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 524f6970992a..8959ea25579c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -46,6 +46,20 @@ sealed abstract class BlockId {
     case o: BlockId => getClass == o.getClass && name.equals(o.name)
     case _ => false
   }
+
+  def getRddId: Int = {
+    asRDDId match {
+      case Some(x) => x.rddId
+      case _ => -1
+    }
+  }
+
+  def getRddSplitIndex: Int = {
+    asRDDId match {
+      case Some(x) => x.splitIndex
+      case _ => -1
+    }
+  }
 }
 
 @DeveloperApi
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 9608418b435e..825fa0d9c408 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -83,6 +83,13 @@ private[spark] class BlockManager(
 
   // Visible for testing
   private[storage] val blockInfoManager = new BlockInfoManager
+  val blockExInfo = new java.util.HashMap[RDDBlockId, BlockExInfo]
+
+  val inMemBlockExInfo = new java.util.TreeSet[BlockExInfo]
+
+  var stageExInfos: HashMap[Int, StageExInfo] = new HashMap[Int, StageExInfo]
+  var currentStage: Int = -1
+
 
   private val futureExecutionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128))
@@ -1202,6 +1209,17 @@ private[spark] class BlockManager(
       blockId: BlockId,
       data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = {
     logInfo(s"Dropping block $blockId from memory")
+
+    if (blockId.isRDD) {
+      logTrace("change exist status of " + blockId + " to 0")
+      val curValue = blockExInfo.get(blockId)
+      curValue.isExist = 0
+      inMemBlockExInfo.synchronized {
+        logTrace("Remove " + blockId + " from inMemBlockExInfo")
+        inMemBlockExInfo.remove(curValue)
+      }
+    }
+
     val info = blockInfoManager.assertBlockIsLockedForWriting(blockId)
     var blockIsUpdated = false
     val level = info.level
@@ -1253,6 +1271,60 @@ private[spark] class BlockManager(
   def removeRdd(rddId: Int): Int = {
     // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
     logInfo(s"Removing RDD $rddId")
+    logTrace("Now we are in stage: " + currentStage)
+    logTrace(" and the depMap is: " + stageExInfos(currentStage).depMap)
+
+    var targetRdd: Int = -1
+    stageExInfos(currentStage).depMap.foreach { cur =>
+      if (cur._2.contains(rddId)) {
+        targetRdd = cur._1
+      }
+    }
+
+    logTrace("Now we set targetRDD to: " + targetRdd + " because we removeRdd: " + rddId)
+
+    val iter = blockExInfo.entrySet().iterator()
+    while (iter.hasNext) {
+      val cur = iter.next()
+      val curId = cur.getKey.getRddId
+      val splitIndex = cur.getKey.getRddSplitIndex
+      if (curId == targetRdd) {
+        inMemBlockExInfo.synchronized {
+          if (inMemBlockExInfo.contains(cur.getValue)) {
+            logTrace("Remove " + cur.getKey + " from inMemBlockExInfo")
+            inMemBlockExInfo.remove(cur.getValue)
+            cur.getValue.creatCost = cur.getValue.creatCost + blockExInfo.get(
+              new RDDBlockId(rddId, splitIndex)).creatCost
+            cur.getValue.norCost = cur.getValue.
+              creatCost.toDouble / (cur.getValue.size / 1024 / 1024)
+
+            logTrace("Add " + cur.getValue.blockId + " to inMemBlockExInfo")
+            inMemBlockExInfo.add(cur.getValue)
+
+          } else {
+            cur.getValue.creatCost = cur.getValue.creatCost + blockExInfo.get(
+              new RDDBlockId(rddId, splitIndex)).creatCost
+            cur.getValue.norCost = cur.getValue.
+              creatCost.toDouble / (cur.getValue.size / 1024 / 1024)
+          }
+        }
+
+        logTrace("Due to Removing RDD " + rddId + " we have to change rdd_" + curId + "_" +
+          splitIndex + " ctime")
+      } else if (curId == rddId) {
+        val curValue = blockExInfo.get(new RDDBlockId(rddId, splitIndex))
+        curValue.isExist = 0
+        inMemBlockExInfo.synchronized {
+          logTrace("Remove " + curValue.blockId + " from inMemBlockExInfo")
+          inMemBlockExInfo.remove(curValue)
+        }
+        logTrace("Due to Removing RDD " + rddId + " we now change " + curId + "_" +
+          splitIndex + " to not exist and TreeSet have size " + inMemBlockExInfo.size())
+      }
+    }
+
+    // TODO
+    // Update the creation cost of a child RDD when its parent RDD is removed and the child is not resident.
     val blocksToRemove = blockInfoManager.entries.flatMap(_._1.asRDDId).filter(_.rddId == rddId)
     blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
     blocksToRemove.size
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/BlockExInfo.scala b/core/src/main/scala/org/apache/spark/storage/memory/BlockExInfo.scala
new file mode 100644
index 000000000000..8693634702d8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/memory/BlockExInfo.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+class BlockExInfo(val blockId: RDDBlockId) extends Comparable[BlockExInfo] {
+
+  var size: Long = _
+  var creatStartTime: Long = _
+  var creatFinTime: Long = _
+  var creatCost: Long = _
+
+  var serStartTime: Long = _
+  var serFinTime: Long = _
+  var serCost: Long = 0
+  var serAndDeCost: Long = _
+
+  var fakeSerCost: Long = 0
+
+  var isExist: Int = 0
+  // 0: not exist; 1: in memory; 2: serialized on disk
+  var norCost: Double = _ // normalized cost
+
+  var sonSet: Set[BlockId] = Set()
+
+  // record the creation finish time and calculate the creation cost
+  def writeFinAndCalCreatCost(finTime: Long) {
+    creatFinTime = finTime
+    creatCost = creatFinTime - creatStartTime
+    norCost = creatCost.toDouble / (size / 1024 / 1024)
+    isExist = 1
+  }
+
+  def writeAndCalSerCost(serStart: Long, serFin: Long): Unit = {
+    serStartTime = serStart
+    serFinTime = serFin
+    serCost = serFinTime - serStartTime
+    isExist = 2
+  }
+
+  def decidePolicy: Int = {
+    if (creatCost < serAndDeCost) {
+      norCost = creatCost.toDouble / size
+      3 // creation cost is low, so just remove the block from memory
+    } else {
+      norCost = serAndDeCost.toDouble / size
+      4 // ser and deser cost is low, so just serialize the block to disk
+    }
+  }
+
+  override def compareTo(o: BlockExInfo): Int = {
+    this.norCost.compare(o.norCost)
+  }
+}
\ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryEntryManager.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryEntryManager.scala
new file mode 100644
index 000000000000..5ffa972c9c37
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryEntryManager.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage.memory
+
+import java.util
+
+import scala.collection.mutable.ArrayBuffer
+
+trait MemoryEntryManager[K, V] {
+  def getEntry(blockId: K): V
+
+  def putEntry(key: K, value: V): V
+
+  def removeEntry(key: K): V
+
+  def clear()
+
+  def containsEntry(key: K): Boolean
+}
+
+class FIFOMemoryEntryManager[K, V] extends MemoryEntryManager[K, V] {
+  val entries = new util.LinkedHashMap[K, V](32, 0.75f)
+
+  override def getEntry(key: K): V = {
+    entries.synchronized {
+      entries.get(key)
+    }
+  }
+
+  override def putEntry(key: K, value: V): V = {
+    entries.synchronized {
+      entries.put(key, value)
+    }
+  }
+
+  def clear() {
+    entries.synchronized {
+      entries.clear()
+    }
+  }
+
+  override def removeEntry(key: K): V = {
+    entries.synchronized {
+      entries.remove(key)
+    }
+  }
+
+  override def containsEntry(key: K): Boolean = {
+    entries.synchronized {
+      entries.containsKey(key)
+    }
+  }
+}
+
+class LRUMemoryEntryManager[K, V] extends MemoryEntryManager[K, V] {
+  def entrySet() : util.Set[util.Map.Entry[K, V]] = {
+    entries.entrySet()
+  }
+
+  val entries = new util.LinkedHashMap[K, V](32, 0.75f, true)
+
+  override def getEntry(key: K): V = {
+    entries.synchronized {
+      entries.get(key)
+    }
+  }
+
+  override def putEntry(key: K, value: V): V = {
+    entries.synchronized {
+      entries.put(key, value)
+    }
+  }
+
+  def clear() {
+    entries.synchronized {
+      entries.clear()
+    }
+  }
+
+  override def removeEntry(key: K): V = {
+    entries.synchronized {
+      entries.remove(key)
+    }
+  }
+
+  override def containsEntry(key: K): Boolean = {
+    entries.synchronized {
+      entries.containsKey(key)
+    }
+  }
+}
\ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index 99be4de0658c..87bc1137115c 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -31,7 +31,7 @@ import org.apache.spark.{SparkConf, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.serializer.{SerializationStream, SerializerManager}
-import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
+import org.apache.spark.storage.{BlockId, BlockInfoManager, BlockManager, StorageLevel}
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
@@ -84,10 +84,12 @@ private[spark] class MemoryStore(
     blockEvictionHandler: BlockEvictionHandler)
   extends Logging {
 
+  val blockManager = blockEvictionHandler.asInstanceOf[BlockManager]
   // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and
   // acquiring or releasing unroll memory, must be synchronized on `memoryManager`!
 
-  private val entries = new LinkedHashMap[BlockId, MemoryEntry[_]](32, 0.75f, true)
+//  private val entries = new LinkedHashMap[BlockId, MemoryEntry[_]](32, 0.75f, true)
+  private val entries = new LRUMemoryEntryManager[BlockId, MemoryEntry[_]]
 
   // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes)
   // All accesses of this map are assumed to have manually synchronized on `memoryManager`
@@ -123,9 +125,7 @@ private[spark] class MemoryStore(
   }
 
   def getSize(blockId: BlockId): Long = {
-    entries.synchronized {
-      entries.get(blockId).size
-    }
+    entries.getEntry(blockId).size
   }
 
   /**
@@ -147,9 +147,7 @@ private[spark] class MemoryStore(
       val bytes = _bytes()
       assert(bytes.size == size)
       val entry = new SerializedMemoryEntry[T](bytes, memoryMode, implicitly[ClassTag[T]])
-      entries.synchronized {
-        entries.put(blockId, entry)
-      }
+      entries.putEntry(blockId, entry)
       logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
         blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
       true
@@ -264,9 +262,7 @@ private[spark] class MemoryStore(
         }
       }
       if (enoughStorageMemory) {
-        entries.synchronized {
-          entries.put(blockId, entry)
-        }
+        entries.putEntry(blockId, entry)
         logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(
           blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
         Right(size)
@@ -373,9 +369,7 @@ private[spark] class MemoryStore(
         val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode)
         assert(success, "transferring unroll memory to storage memory failed")
       }
-      entries.synchronized {
-        entries.put(blockId, entry)
-      }
+      entries.putEntry(blockId, entry)
       logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
         blockId, Utils.bytesToString(entry.size), Utils.bytesToString(blocksMemoryUsed)))
       Right(entry.size)
@@ -398,7 +392,7 @@ private[spark] class MemoryStore(
   }
 
   def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
-    val entry = entries.synchronized { entries.get(blockId) }
+    val entry = entries.getEntry(blockId)
     entry match {
       case null => None
       case e: DeserializedMemoryEntry[_] =>
@@ -408,7 +402,7 @@ private[spark] class MemoryStore(
   }
 
   def getValues(blockId: BlockId): Option[Iterator[_]] = {
-    val entry = entries.synchronized { entries.get(blockId) }
+    val entry = entries.getEntry(blockId)
     entry match {
       case null => None
       case e: SerializedMemoryEntry[_] =>
@@ -420,9 +414,7 @@ private[spark] class MemoryStore(
   }
 
   def remove(blockId: BlockId): Boolean = memoryManager.synchronized {
-    val entry = entries.synchronized {
-      entries.remove(blockId)
-    }
+    val entry = entries.removeEntry(blockId)
     if (entry != null) {
       entry match {
         case SerializedMemoryEntry(buffer, _, _) => buffer.dispose()
@@ -498,6 +490,29 @@ private[spark] class MemoryStore(
       }
     }
 
+    if (conf.get("***") == "LCS") {
+      blockManager.inMemBlockExInfo.synchronized {
+        val setIter = blockManager.inMemBlockExInfo.iterator()
+        while (freedMemory < space && setIter.hasNext) {
+          val cur = setIter.next()
+
+          blockManager.stageExInfos.get(blockManager.currentStage) match {
+            case Some(curStageExInfo) =>
+              // only pick blocks that are not outputs of the currently running stage
+              if (!curStageExInfo.curRunningRddMap.contains(cur.blockId.getRddId)) {
+                if (blockInfoManager.lockForWriting(cur.blockId,
+                  blocking = false).isDefined) {
+                  selectedBlocks += cur.blockId
+                  freedMemory += cur.size
+                }
+              }
+            case None =>
+              logError("No StageExInfo found for stage " + blockManager.currentStage)
+          }
+        }
+      }
+    }
+
     def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = {
      val data = entry match {
        case DeserializedMemoryEntry(values, _, _) => Left(values)
@@ -520,7 +535,7 @@ private[spark] class MemoryStore(
       logInfo(s"${selectedBlocks.size} blocks selected for dropping " +
         s"(${Utils.bytesToString(freedMemory)} bytes)")
       for (blockId <- selectedBlocks) {
-        val entry = entries.synchronized { entries.get(blockId) }
+        val entry = entries.getEntry(blockId)
         // This should never be null as only one task should be dropping
         // blocks and removing entries. However the check is still here for
         // future safety.
@@ -544,7 +559,7 @@ private[spark] class MemoryStore(
   }
 
   def contains(blockId: BlockId): Boolean = {
-    entries.synchronized { entries.containsKey(blockId) }
+    entries.containsEntry(blockId)
   }
 
   private def currentTaskAttemptId(): Long = {
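The MemoryStore changes above add an alternative eviction path: when the configuration value (left as the "***" placeholder in the patch) equals "LCS", eviction candidates are drawn from blockManager.inMemBlockExInfo, a TreeSet ordered by BlockExInfo.compareTo, i.e. by norCost (creation cost per MiB), and blocks produced by the currently running stage are skipped. The standalone sketch below models only that selection order; EvictionCandidate, selectVictims and the sample numbers are illustrative assumptions, not part of the patch.

import java.util.TreeSet
import scala.collection.JavaConverters._

// Stand-in for BlockExInfo: ordered by normalized creation cost (cost per MiB),
// so the cheapest-to-recompute blocks are visited first.
case class EvictionCandidate(blockId: String, rddId: Int, sizeBytes: Long, creatCostMs: Long)
  extends Comparable[EvictionCandidate] {
  def norCost: Double = creatCostMs.toDouble / (sizeBytes.toDouble / 1024 / 1024)
  override def compareTo(o: EvictionCandidate): Int = norCost.compare(o.norCost)
}

object LcsSelection {
  // Walk candidates in norCost order, skipping blocks whose RDD is still needed
  // by the running stage, until enough bytes have been claimed.
  def selectVictims(candidates: TreeSet[EvictionCandidate],
      runningRdds: Set[Int], spaceNeeded: Long): Seq[EvictionCandidate] = {
    var freed = 0L
    val picked = Seq.newBuilder[EvictionCandidate]
    val it = candidates.iterator().asScala
    while (freed < spaceNeeded && it.hasNext) {
      val c = it.next()
      if (!runningRdds.contains(c.rddId)) { picked += c; freed += c.sizeBytes }
    }
    picked.result()
  }

  def main(args: Array[String]): Unit = {
    val set = new TreeSet[EvictionCandidate]()
    set.add(EvictionCandidate("rdd_3_0", 3, 64L << 20, creatCostMs = 200))  // cheap to rebuild
    set.add(EvictionCandidate("rdd_7_0", 7, 64L << 20, creatCostMs = 5000)) // expensive to rebuild
    println(selectVictims(set, runningRdds = Set(7), spaceNeeded = 64L << 20).map(_.blockId))
    // List(rdd_3_0): the cheap, not-currently-needed block is evicted first
  }
}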