core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -19,6 +19,7 @@ package org.apache.spark.network

import scala.reflect.ClassTag

+import org.apache.spark.TaskContext
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.StreamCallbackWithID
import org.apache.spark.storage.{BlockId, StorageLevel}
@@ -58,5 +59,5 @@ trait BlockDataManager
/**
* Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
*/
-def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit
+def releaseLock(blockId: BlockId, taskContext: Option[TaskContext]): Unit
}
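
Note: with the new signature, callers of `BlockDataManager.releaseLock` hand over the whole `TaskContext` instead of a pre-extracted task attempt id. A minimal sketch of the new call pattern (the helper `releaseAfterUse` is hypothetical, not part of this PR):

    import org.apache.spark.TaskContext
    import org.apache.spark.network.BlockDataManager
    import org.apache.spark.storage.BlockId

    // Hypothetical helper for illustration: TaskContext.get() returns null
    // when called outside a task, so the context is wrapped in an Option
    // before being handed to releaseLock.
    def releaseAfterUse(manager: BlockDataManager, blockId: BlockId): Unit = {
      val taskContext = Option(TaskContext.get())
      manager.releaseLock(blockId, taskContext)
    }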
core/src/main/scala/org/apache/spark/storage/BlockManager.scala (27 changes: 17 additions & 10 deletions)
@@ -731,7 +731,7 @@ private[spark] class BlockManager(
case Some(info) =>
val level = info.level
logDebug(s"Level for block $blockId is $level")
-val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
+val taskContext = Option(TaskContext.get())
if (level.useMemory && memoryStore.contains(blockId)) {
val iter: Iterator[Any] = if (level.deserialized) {
memoryStore.getValues(blockId).get
@@ -743,7 +743,7 @@
// from a different thread which does not have TaskContext set; see SPARK-18406 for
// discussion.
val ci = CompletionIterator[Any, Iterator[Any]](iter, {
-releaseLock(blockId, taskAttemptId)
+releaseLock(blockId, taskContext)
})
Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
} else if (level.useDisk && diskStore.contains(blockId)) {
@@ -762,7 +762,7 @@
}
}
val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, {
-releaseLockAndDispose(blockId, diskData, taskAttemptId)
+releaseLockAndDispose(blockId, diskData, taskContext)
})
Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
} else {
@@ -991,13 +991,20 @@
}

/**
- * Release a lock on the given block with explicit TID.
- * The param `taskAttemptId` should be passed in case we can't get the correct TID from
- * TaskContext, for example, the input iterator of a cached RDD iterates to the end in a child
+ * Release a lock on the given block with explicit TaskContext.
+ * The param `taskContext` should be passed in case we can't get the correct TaskContext,
+ * for example, the input iterator of a cached RDD iterates to the end in a child

Review comment (Member): nit: a missing `,` before for example.

* thread.
*/
-def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = {
-  blockInfoManager.unlock(blockId, taskAttemptId)
+def releaseLock(blockId: BlockId, taskContext: Option[TaskContext] = None): Unit = {
+  val taskAttemptId = taskContext.map(_.taskAttemptId())
+  // SPARK-27666. When a task completes, Spark automatically releases all the blocks locked
+  // by this task. We should not release any locks for a task that is already completed.
+  if (taskContext.isDefined && taskContext.get.isCompleted) {
+    logWarning(s"Task ${taskAttemptId.get} already completed, not releasing lock for $blockId")
+  } else {
+    blockInfoManager.unlock(blockId, taskAttemptId)
+  }
}

/**
@@ -1666,8 +1673,8 @@ private[spark] class BlockManager(
def releaseLockAndDispose(
blockId: BlockId,
data: BlockData,
-taskAttemptId: Option[Long] = None): Unit = {
-releaseLock(blockId, taskAttemptId)
+taskContext: Option[TaskContext] = None): Unit = {
+releaseLock(blockId, taskContext)
data.dispose()
}

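Note: the guard added to `releaseLock` above targets the situation sketched below, a condensed version of the regression test that follows; it assumes a running `SparkContext` named `sc`:

    val rdd = sc.parallelize(0 until 10, 1).cache()
    rdd.collect()  // materialize the cached block so later reads take a read lock

    rdd.mapPartitions { iter =>
      // Hand the cached block's iterator to a thread that outlives the task.
      val t = new Thread(() => iter.foreach(_ => ()))
      t.start()
      Iterator(0)
    }.collect()

    // When the task completes, Spark releases every lock the task acquired.
    // The child thread exhausts the iterator afterwards, so its
    // CompletionIterator still calls releaseLock(blockId, taskContext);
    // because taskContext.isCompleted is now true, the call becomes a logged
    // no-op instead of releasing a lock the task no longer holds.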
core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala (30 changes: 29 additions & 1 deletion)
@@ -18,22 +18,25 @@
package org.apache.spark.rdd

import java.io.{File, IOException, ObjectInputStream, ObjectOutputStream}
+import java.lang.management.ManagementFactory

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.concurrent.duration._
import scala.reflect.ClassTag

import com.esotericsoftware.kryo.KryoException
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
+import org.scalatest.concurrent.Eventually

import org.apache.spark._
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
import org.apache.spark.rdd.RDDSuiteUtils._
import org.apache.spark.util.{ThreadUtils, Utils}

-class RDDSuite extends SparkFunSuite with SharedSparkContext {
+class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
var tempDir: File = _

override def beforeAll(): Unit = {
@@ -1176,6 +1179,31 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
}.collect()
}

test("SPARK-27666: Do not release lock while TaskContext already completed") {
val rdd = sc.parallelize(Range(0, 10), 1).cache()
val tid = sc.longAccumulator("threadId")
// validate cache
rdd.collect()
rdd.mapPartitions { iter =>
val t = new Thread(() => {
while (iter.hasNext) {
iter.next()
Thread.sleep(100)
}
})
t.setDaemon(false)
t.start()
tid.add(t.getId)
Iterator(0)
}.collect()
val tmx = ManagementFactory.getThreadMXBean
eventually(timeout(10.seconds)) {
// getThreadInfo() will return null after child thread `t` died
val t = tmx.getThreadInfo(tid.value)
assert(t == null || t.getThreadState == Thread.State.TERMINATED)
}
}

Review comment (viirya, Member, Jun 17, 2019): Can we use `eventually` with timeout instead of a while loop?

    eventually(timeout(10.seconds)) {
      val t = tmx.getThreadInfo(tid.value)
      assert(t == null || t.getThreadState == Thread.State.TERMINATED)
    }

Reply (Member, author): good idea!

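Note: the adopted assertion leans on a documented `java.lang.management` behavior: `ThreadMXBean.getThreadInfo(id)` returns null once the thread with that id is no longer alive. A standalone illustration, independent of Spark:

    import java.lang.management.ManagementFactory

    val tmx = ManagementFactory.getThreadMXBean
    val t = new Thread(() => ())  // a thread that exits immediately
    t.start()
    t.join()
    // Once the thread has died, getThreadInfo returns null (or, at worst, a
    // ThreadInfo whose state is TERMINATED), which is exactly what the test
    // above polls for with eventually.
    val info = tmx.getThreadInfo(t.getId)
    assert(info == null || info.getThreadState == Thread.State.TERMINATED)
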
test("SPARK-23496: order of input partitions can result in severe skew in coalesce") {
val numInputPartitions = 100
val numCoalescedPartitions = 50