41 changes: 26 additions & 15 deletions core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,6 +17,9 @@

package org.apache.spark.executor

import java.util.{ArrayList, Collections}

import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}

import org.apache.spark._
@@ -99,7 +102,11 @@ class TaskMetrics private[spark] () extends Serializable {
/**
* Storage statuses of any blocks that have been updated as a result of this task.
*/
def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.value
def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = {
// This is called on driver. All accumulator updates have a fixed value. So it's safe to use
// `asScala` which accesses the internal values using `java.util.Iterator`.
_updatedBlockStatuses.value.asScala
}

// Setters and increment-ers
private[spark] def setExecutorDeserializeTime(v: Long): Unit =
@@ -114,8 +121,10 @@ class TaskMetrics private[spark] () extends Serializable {
private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v)
private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit =
_updatedBlockStatuses.add(v)
private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
private[spark] def setUpdatedBlockStatuses(v: java.util.List[(BlockId, BlockStatus)]): Unit =
_updatedBlockStatuses.setValue(v)
private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
_updatedBlockStatuses.setValue(v.asJava)

/**
* Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted
@@ -268,7 +277,7 @@ private[spark] object TaskMetrics extends Logging {
val name = info.name.get
val value = info.update.get
if (name == UPDATED_BLOCK_STATUSES) {
tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, BlockStatus)]])
tm.setUpdatedBlockStatuses(value.asInstanceOf[java.util.List[(BlockId, BlockStatus)]])
} else {
tm.nameToAccums.get(name).foreach(
_.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long])
@@ -299,34 +308,36 @@


private[spark] class BlockStatusesAccumulator
extends AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] {
private var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)]
extends AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]] {
private val _seq = Collections.synchronizedList(new ArrayList[(BlockId, BlockStatus)]())

override def isZero(): Boolean = _seq.isEmpty

override def copyAndReset(): BlockStatusesAccumulator = new BlockStatusesAccumulator

override def copy(): BlockStatusesAccumulator = {
val newAcc = new BlockStatusesAccumulator
newAcc._seq = _seq.clone()
newAcc._seq.addAll(_seq)
newAcc
}

override def reset(): Unit = _seq.clear()

override def add(v: (BlockId, BlockStatus)): Unit = _seq += v
override def add(v: (BlockId, BlockStatus)): Unit = _seq.add(v)

override def merge(other: AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]])
: Unit = other match {
case o: BlockStatusesAccumulator => _seq ++= o.value
case _ => throw new UnsupportedOperationException(
s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
override def merge(
other: AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]]): Unit = {
other match {
case o: BlockStatusesAccumulator => _seq.addAll(o.value)
case _ => throw new UnsupportedOperationException(
s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
}

override def value: Seq[(BlockId, BlockStatus)] = _seq
override def value: java.util.List[(BlockId, BlockStatus)] = _seq

def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = {
def setValue(newValue: java.util.List[(BlockId, BlockStatus)]): Unit = {
_seq.clear()
_seq ++= newValue
_seq.addAll(newValue)
}
}
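
As the new comment in `updatedBlockStatuses` points out, `asScala` only wraps the underlying `java.util.List`, so iterating it is safe only once the value has stopped changing. Below is a minimal standalone sketch of that pattern, using plain JDK collections and hypothetical names rather than the real `BlockStatusesAccumulator`:

```scala
import java.util.{ArrayList, Collections}

import scala.collection.JavaConverters._

object SynchronizedListSketch {
  def main(args: Array[String]): Unit = {
    // Backing store analogous to BlockStatusesAccumulator's _seq.
    val updates = Collections.synchronizedList(new ArrayList[(String, Long)]())

    // Simulated task threads appending updates concurrently; every add() takes
    // the wrapper's lock, so the list never gets corrupted.
    val writers = (1 to 4).map { i =>
      new Thread(new Runnable {
        override def run(): Unit = {
          for (j <- 1 to 1000) updates.add((s"block-$i-$j", j.toLong))
        }
      })
    }
    writers.foreach(_.start())
    writers.foreach(_.join())

    // "Driver side": asScala is just a wrapper (no copy), so it is only
    // iterated after all writers have finished, mirroring the comment above.
    val asSeq: Seq[(String, Long)] = updates.asScala
    println(asSeq.size) // 4000
  }
}
```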
7 changes: 5 additions & 2 deletions core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util

import java.{lang => jl}
import java.io.ObjectInputStream
import java.util.ArrayList
import java.util.{ArrayList, Collections}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

@@ -38,6 +38,9 @@ private[spark] case class AccumulatorMetadata(
/**
* The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of
* type `OUT`.
*
* `OUT` should be a type that can be read atomically (e.g., Int, Long), or thread-safely
* (e.g., synchronized collections) because it will be read from other threads.
*/
abstract class AccumulatorV2[IN, OUT] extends Serializable {
private[spark] var metadata: AccumulatorMetadata = _
@@ -433,7 +436,7 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
* @since 2.0.0
*/
class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
private val _list: java.util.List[T] = new ArrayList[T]
private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]())

override def isZero: Boolean = _list.isEmpty

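
The new scaladoc contract on `OUT` is what `CollectionAccumulator` now satisfies by wrapping its list with `Collections.synchronizedList`: the heartbeat thread can serialize in-progress values while tasks are still calling `add`. A hedged usage sketch of how user code sees this (local master, illustrative names):

```scala
import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkContext}

object CollectionAccumulatorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("acc-sketch").setMaster("local[2]"))
    val seen = sc.collectionAccumulator[String]("seen")

    sc.parallelize(1 to 100, numSlices = 4).foreach { i =>
      // Executor side: add() goes through the synchronized backing list.
      seen.add(s"value-$i")
    }

    // Driver side: the value is fixed once the job has finished, so wrapping
    // it with asScala and iterating is safe without extra locking.
    println(seen.value.asScala.size) // 100
    sc.stop()
  }
}
```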
11 changes: 6 additions & 5 deletions core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -310,11 +310,12 @@ private[spark] object JsonProtocol {
case v: Int => JInt(v)
case v: Long => JInt(v)
// We only have 3 kind of internal accumulator types, so if it's not int or long, it must be
// the blocks accumulator, whose type is `Seq[(BlockId, BlockStatus)]`
// the blocks accumulator, whose type is `java.util.List[(BlockId, BlockStatus)]`
case v =>
JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) =>
("Block ID" -> id.toString) ~
("Status" -> blockStatusToJson(status))
JArray(v.asInstanceOf[java.util.List[(BlockId, BlockStatus)]].asScala.toList.map {
case (id, status) =>
("Block ID" -> id.toString) ~
("Status" -> blockStatusToJson(status))
})
}
} else {
@@ -743,7 +744,7 @@ private[spark] object JsonProtocol {
val id = BlockId((blockJson \ "Block ID").extract[String])
val status = blockStatusFromJson(blockJson \ "Status")
(id, status)
}
}.asJava
case _ => throw new IllegalArgumentException(s"unexpected json value $value for " +
"accumulator " + name.get)
}
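
For reference, a small sketch of the conversion pattern `JsonProtocol` now relies on, with assumed stand-in types rather than the real `accumValueToJson`/`accumValueFromJson`: `asScala`/`asJava` only wrap the underlying collection, and `.toList` takes an immutable copy before mapping.

```scala
import java.util.{ArrayList, Collections}

import scala.collection.JavaConverters._

object ConverterRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val javaList = Collections.synchronizedList(new ArrayList[(String, Int)]())
    javaList.add(("rdd_0_0", 1))

    // Write path: wrap with asScala, copy with toList, then render each entry.
    val rendered = javaList.asScala.toList.map { case (id, status) => s"$id -> $status" }

    // Read path: rebuild the entries in Scala, then hand back a
    // java.util.List wrapper, as the .asJava above does.
    val parsed: java.util.List[(String, Int)] = rendered.map { s =>
      val Array(id, status) = s.split(" -> ")
      (id, status.toInt)
    }.asJava

    println(parsed.size()) // 1
  }
}
```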
core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util

import java.util.Properties

import scala.collection.JavaConverters._
import scala.collection.Map

import org.json4s.jackson.JsonMethods._
@@ -415,7 +416,7 @@ class JsonProtocolSuite extends SparkFunSuite {
})
testAccumValue(Some(RESULT_SIZE), 3L, JInt(3))
testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2))
testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson)
testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks.asJava, blocksJson)
// For anything else, we just cast the value to a string
testAccumValue(Some("anything"), blocks, JString(blocks.toString))
testAccumValue(Some("anything"), 123, JString("123"))
sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -17,7 +17,9 @@

package org.apache.spark.sql.execution

import scala.collection.mutable.HashSet
import java.util.Collections

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
@@ -107,18 +109,20 @@ package object debug {
case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
def output: Seq[Attribute] = child.output

class SetAccumulator[T] extends AccumulatorV2[T, HashSet[T]] {
private val _set = new HashSet[T]()
class SetAccumulator[T] extends AccumulatorV2[T, java.util.Set[T]] {
private val _set = Collections.synchronizedSet(new java.util.HashSet[T]())

Contributor:
If you use Collections.synchronized*, will serialization of those objects also be thread-safe (i.e. will writeObject synchronize properly)? What about if Kryo is used?

Member (Author):
For Java serialization, it's synchronized. See: http://www.grepcode.com/file/repository.grepcode.com/java/root/jdk/openjdk/8u40-b25/java/util/Collections.java#2080

Do we use Kryo to serialize Heartbeat?

Contributor:
Nope, it looks like NettyRpcEnv.serialize is hardcoded to use a JavaSerializer instance.

override def isZero: Boolean = _set.isEmpty
override def copy(): AccumulatorV2[T, HashSet[T]] = {
override def copy(): AccumulatorV2[T, java.util.Set[T]] = {
val newAcc = new SetAccumulator[T]()
newAcc._set ++= _set
newAcc._set.addAll(_set)
newAcc
}
override def reset(): Unit = _set.clear()
override def add(v: T): Unit = _set += v
override def merge(other: AccumulatorV2[T, HashSet[T]]): Unit = _set ++= other.value
override def value: HashSet[T] = _set
override def add(v: T): Unit = _set.add(v)
override def merge(other: AccumulatorV2[T, java.util.Set[T]]): Unit = {
_set.addAll(other.value)
}
override def value: java.util.Set[T] = _set
}

/**
@@ -138,7 +142,9 @@ package object debug {
debugPrint(s"== ${child.simpleString} ==")
debugPrint(s"Tuples output: ${tupleCount.value}")
child.output.zip(columnStats).foreach { case (attr, metric) =>
val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
// This is called on driver. All accumulator updates have a fixed value. So it's safe to use
// `asScala` which accesses the internal values using `java.util.Iterator`.
val actualDataTypes = metric.elementTypes.value.asScala.mkString("{", ",", "}")
debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
}
}
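
Finally, a standalone sketch of the scenario raised in the review thread above: one thread keeps adding to a synchronized list while another Java-serializes it, as the heartbeat thread does with in-progress accumulator values. Only JDK classes are used; the object and thread names are illustrative.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.util.{ArrayList, Collections}

object SerializationRaceSketch {
  def main(args: Array[String]): Unit = {
    val list = Collections.synchronizedList(new ArrayList[String]())

    val adder = new Thread(new Runnable {
      override def run(): Unit = {
        for (i <- 1 to 100000) list.add(s"update-$i")
      }
    })
    adder.start()

    // Simulated heartbeat: repeatedly serialize whatever has been added so far.
    // The wrapper's writeObject takes the same lock as add(), so each snapshot
    // is consistent and no ConcurrentModificationException is thrown.
    while (adder.isAlive) {
      val out = new ObjectOutputStream(new ByteArrayOutputStream())
      out.writeObject(list)
      out.close()
    }
    adder.join()
    println(list.size()) // 100000
  }
}
```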