64 changes: 16 additions & 48 deletions core/src/main/scala/org/apache/spark/Accumulable.scala
@@ -63,7 +63,7 @@ class Accumulable[R, T] private (
param: AccumulableParam[R, T],
name: Option[String],
countFailedValues: Boolean) = {
this(Accumulators.newId(), initialValue, param, name, countFailedValues)
this(AccumulatorContext.newId(), initialValue, param, name, countFailedValues)
}

private[spark] def this(initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = {
@@ -72,61 +72,44 @@ class Accumulable[R, T] private (

def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None)

@volatile @transient private var value_ : R = initialValue // Current value on driver
val zero = param.zero(initialValue) // Zero value to be passed to executors
private var deserialized = false

Accumulators.register(this)

/**
* Return a copy of this [[Accumulable]].
*
* The copy will have the same ID as the original and will not be registered with
* [[Accumulators]] again. This method exists so that the caller can avoid passing the
* same mutable instance around.
*/
private[spark] def copy(): Accumulable[R, T] = {
new Accumulable[R, T](id, initialValue, param, name, countFailedValues)
}
val zero = param.zero(initialValue)
private[spark] val newAcc = new LegacyAccumulatorWrapper(initialValue, param)
newAcc.metadata = AccumulatorMetadata(id, name, countFailedValues)
// Register the new accumulator in ctor, to follow the previous behaviour.
AccumulatorContext.register(newAcc)

/**
* Add more data to this accumulator / accumulable
* @param term the data to add
*/
def += (term: T) { value_ = param.addAccumulator(value_, term) }
def += (term: T) { newAcc.add(term) }

/**
* Add more data to this accumulator / accumulable
* @param term the data to add
*/
def add(term: T) { value_ = param.addAccumulator(value_, term) }
def add(term: T) { newAcc.add(term) }

/**
* Merge two accumulable objects together
*
* Normally, a user will not want to use this version, but will instead call `+=`.
* @param term the other `R` that will get merged with this
*/
def ++= (term: R) { value_ = param.addInPlace(value_, term)}
def ++= (term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }

/**
* Merge two accumulable objects together
*
* Normally, a user will not want to use this version, but will instead call `add`.
* @param term the other `R` that will get merged with this
*/
def merge(term: R) { value_ = param.addInPlace(value_, term)}
def merge(term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }

/**
* Access the accumulator's current value; only allowed on driver.
*/
def value: R = {
if (!deserialized) {
value_
} else {
throw new UnsupportedOperationException("Can't read accumulator value in task")
}
}
def value: R = newAcc.value

/**
* Get the current value of this accumulator from within a task.
@@ -137,14 +120,14 @@ class Accumulable[R, T] private (
* The typical use of this method is to directly mutate the local value, eg., to add
* an element to a Set.
*/
def localValue: R = value_
def localValue: R = newAcc.localValue

/**
* Set the accumulator's value; only allowed on driver.
*/
def value_= (newValue: R) {
if (!deserialized) {
value_ = newValue
if (newAcc.isAtDriverSide) {
newAcc._value = newValue
} else {
throw new UnsupportedOperationException("Can't assign accumulator value in task")
}
@@ -153,7 +136,7 @@ class Accumulable[R, T] private (
/**
* Set the accumulator's value. For internal use only.
*/
def setValue(newValue: R): Unit = { value_ = newValue }
def setValue(newValue: R): Unit = { newAcc._value = newValue }

/**
* Set the accumulator's value. For internal use only.
@@ -168,22 +151,7 @@ class Accumulable[R, T] private (
new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
}

// Called by Java when deserializing an object
private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
value_ = zero
deserialized = true

// Automatically register the accumulator when it is deserialized with the task closure.
// This is for external accumulators and internal ones that do not represent task level
// metrics, e.g. internal SQL metrics, which are per-operator.
val taskContext = TaskContext.get()
if (taskContext != null) {
taskContext.registerAccumulator(this)
}
}

override def toString: String = if (value_ == null) "null" else value_.toString
override def toString: String = if (newAcc._value == null) "null" else newAcc._value.toString
}
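
The pattern in the diff above is delegation: the legacy Accumulable no longer keeps its own value_, it forwards every mutation to a wrapped new-style accumulator (newAcc). Below is a minimal sketch of that wrapper idea under a simplified interface; SimpleAccumulator and LegacyWrapper are illustrative names, not the actual NewAccumulator / LegacyAccumulatorWrapper classes, whose metadata and driver/executor bookkeeping are omitted.

import org.apache.spark.AccumulableParam

// Illustrative stand-in for the new-style accumulator interface.
trait SimpleAccumulator[IN, OUT] extends Serializable {
  def add(v: IN): Unit
  def value: OUT
}

// Keep the old AccumulableParam semantics but expose them through the
// new-style interface, so every legacy method body in Accumulable can
// become a one-line call into the wrapped instance.
class LegacyWrapper[R, T](initialValue: R, param: AccumulableParam[R, T])
  extends SimpleAccumulator[T, R] {

  private var _value: R = initialValue

  override def add(v: T): Unit = { _value = param.addAccumulator(_value, v) }
  override def value: R = _value
}

With that shape, `+=`, `add`, `value` and `setValue` above reduce to forwarding calls, while the merge operations (`++=`, `merge`) reach into the wrapper's internal value via param.addInPlace, as the diff shows.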


67 changes: 0 additions & 67 deletions core/src/main/scala/org/apache/spark/Accumulator.scala
@@ -68,73 +68,6 @@ class Accumulator[T] private[spark] (
extends Accumulable[T, T](initialValue, param, name, countFailedValues)


// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private[spark] object Accumulators extends Logging {
/**
* This global map holds the original accumulator objects that are created on the driver.
* It keeps weak references to these objects so that accumulators can be garbage-collected
* once the RDDs and user-code that reference them are cleaned up.
* TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
*/
@GuardedBy("Accumulators")
val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()

private val nextId = new AtomicLong(0L)

/**
* Return a globally unique ID for a new [[Accumulable]].
* Note: Once you copy the [[Accumulable]] the ID is no longer unique.
*/
def newId(): Long = nextId.getAndIncrement

/**
* Register an [[Accumulable]] created on the driver such that it can be used on the executors.
*
* All accumulators registered here can later be used as a container for accumulating partial
* values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does.
* Note: if an accumulator is registered here, it should also be registered with the active
* context cleaner for cleanup so as to avoid memory leaks.
*
* If an [[Accumulable]] with the same ID was already registered, this does nothing instead
* of overwriting it. This happens when we copy accumulators, e.g. when we reconstruct
* [[org.apache.spark.executor.TaskMetrics]] from accumulator updates.
*/
def register(a: Accumulable[_, _]): Unit = synchronized {
if (!originals.contains(a.id)) {
originals(a.id) = new WeakReference[Accumulable[_, _]](a)
}
}

/**
* Unregister the [[Accumulable]] with the given ID, if any.
*/
def remove(accId: Long): Unit = synchronized {
originals.remove(accId)
}

/**
* Return the [[Accumulable]] registered with the given ID, if any.
*/
def get(id: Long): Option[Accumulable[_, _]] = synchronized {
originals.get(id).map { weakRef =>
// Since we are storing weak references, we must check whether the underlying data is valid.
weakRef.get.getOrElse {
throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
}
}
}

/**
* Clear all registered [[Accumulable]]s. For testing only.
*/
def clear(): Unit = synchronized {
originals.clear()
}

}
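
The Accumulators object removed above is essentially a driver-side registry: globally unique IDs from an atomic counter, plus a map of weak references so accumulators can be garbage-collected together with the user code that created them. A generic sketch of that registry pattern follows; it is not the exact AccumulatorContext implementation that replaces it.

import java.lang.ref.WeakReference
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable

object Registry {
  // Weak references only: holding accumulators strongly here would keep
  // them alive for the lifetime of the driver.
  private val originals = mutable.Map[Long, WeakReference[AnyRef]]()
  private val nextId = new AtomicLong(0L)

  def newId(): Long = nextId.getAndIncrement()

  // Registering the same ID twice is a no-op, matching the removed
  // register() above (copies of an accumulator share its ID).
  def register(id: Long, acc: AnyRef): Unit = synchronized {
    if (!originals.contains(id)) {
      originals(id) = new WeakReference[AnyRef](acc)
    }
  }

  def remove(id: Long): Unit = synchronized { originals.remove(id) }

  // The referent may already have been collected, so lookups can fail
  // even for an ID that was registered.
  def get(id: Long): Option[AnyRef] = synchronized {
    originals.get(id).flatMap(ref => Option(ref.get()))
  }
}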


/**
* A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add
* in is the same type as the accumulated value. An implicit AccumulatorParam object needs to be
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -144,7 +144,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
registerForCleanup(rdd, CleanRDD(rdd.id))
}

def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = {
registerForCleanup(a, CleanAccum(a.id))
}

@@ -241,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
def doCleanupAccum(accId: Long, blocking: Boolean): Unit = {
try {
logDebug("Cleaning accumulator " + accId)
Accumulators.remove(accId)
AccumulatorContext.remove(accId)
listeners.asScala.foreach(_.accumCleaned(accId))
logInfo("Cleaned accumulator " + accId)
} catch {
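
For context on the ContextCleaner changes above: registerForCleanup associates the accumulator with a CleanAccum(id) task through a weak reference, and once the accumulator is garbage-collected a cleaner thread drains a reference queue and runs doCleanupAccum, which now unregisters via AccumulatorContext.remove. A rough, simplified sketch of that reference-queue mechanism; names and threading are condensed, and the real cleaner handles several cleanup task types.

import java.lang.ref.{ReferenceQueue, WeakReference}
import scala.collection.mutable

sealed trait CleanupTask
case class CleanAccum(id: Long) extends CleanupTask

// A weak reference that remembers which cleanup task to run once its
// referent has been garbage-collected and enqueued by the JVM.
class TaskWeakReference(referent: AnyRef, queue: ReferenceQueue[AnyRef], val task: CleanupTask)
  extends WeakReference[AnyRef](referent, queue)

class MiniCleaner(doCleanupAccum: Long => Unit) {
  private val queue = new ReferenceQueue[AnyRef]()
  // Strong references to the wrappers themselves, so they survive until enqueued.
  private val buffer = mutable.Set[TaskWeakReference]()

  def registerForCleanup(obj: AnyRef, task: CleanupTask): Unit = synchronized {
    buffer += new TaskWeakReference(obj, queue, task)
  }

  // One pass of the cleaning loop; the real ContextCleaner runs this
  // continuously on a daemon thread.
  def cleanOnce(timeoutMs: Long = 100L): Unit = {
    Option(queue.remove(timeoutMs)).foreach { ref =>
      val taskRef = ref.asInstanceOf[TaskWeakReference]
      synchronized { buffer -= taskRef }
      taskRef.task match {
        case CleanAccum(id) => doCleanupAccum(id)
      }
    }
  }
}

In this shape, registerAccumulatorForCleanup(a) above amounts to registerForCleanup(a, CleanAccum(a.id)), and a collected accumulator eventually triggers the AccumulatorContext.remove and listener-notification path shown in doCleanupAccum.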
@@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
*/
private[spark] case class Heartbeat(
executorId: String,
accumUpdates: Array[(Long, Seq[AccumulableInfo])], // taskId -> accum updates
accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], // taskId -> accumulator updates
blockManagerId: BlockManagerId)
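
The Heartbeat change above swaps the payload type: instead of per-task AccumulableInfo records (id, name, partial update), executors now ship the task's accumulator objects themselves, and the driver reads their partial values directly. Purely illustrative before/after shapes; every name other than those shown in the diff is made up for the sketch.

// Old shape: a heartbeat carried lightweight info records describing each update.
case class AccumInfoRecord(id: Long, name: Option[String], update: Option[Any])

case class OldHeartbeat(
    executorId: String,
    accumUpdates: Array[(Long, Seq[AccumInfoRecord])])  // taskId -> info records

// New shape: the accumulators travel with the heartbeat; Acc stands in for
// NewAccumulator[_, _], whose partial values the driver merges by ID.
case class NewHeartbeat[Acc](
    executorId: String,
    accumUpdates: Array[(Long, Seq[Acc])])              // taskId -> accumulators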

/**