
Commit bf5496d

cloud-fan authored and rxin committed
[SPARK-14654][CORE] New accumulator API
## What changes were proposed in this pull request?

This PR introduces a new accumulator API which is much simpler than before:

1. The type hierarchy is simplified; now we only have an `Accumulator` class.
2. The `initialValue` and `zeroValue` concepts are combined into a single concept: `zeroValue`.
3. There is only one `register` method; accumulator registration and cleanup registration are combined.
4. The `id`, `name` and `countFailedValues` fields are combined into an `AccumulatorMetadata`, which is provided during registration.

`SQLMetric` is a good example of the simplicity of this new API.

What we break:

1. No `setValue` anymore. In the new API the intermediate type can be different from the result type, so it is very hard to implement a general `setValue`.
2. An accumulator can't be serialized before it is registered.

Problems that need to be addressed in follow-ups:

1. With this new API, `AccumulableInfo` doesn't make a lot of sense: the partial output is not partial updates, so we need to expose the intermediate value.
2. `ExceptionFailure` should not carry the accumulator updates. Why do users care about accumulator updates for failed cases? It looks like we only use this feature to update internal metrics; how about sending a heartbeat to update internal metrics after the failure event?
3. The public event `SparkListenerTaskEnd` carries a `TaskMetrics`. Ideally this `TaskMetrics` doesn't need to carry external accumulators, as the only method of `TaskMetrics` that can access external accumulators is `private[spark]`. However, `SQLListener` uses it to retrieve SQL metrics.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12612 from cloud-fan/acc.
1 parent be317d4 commit bf5496d
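
To make the API shape described in the commit message concrete, here is a minimal, self-contained Scala sketch of the design it describes: a single accumulator class with a zero value, `add`/`merge`/`value` operations, and metadata attached at registration time. The names `ToyAccumulatorMetadata`, `ToyAccumulator` and `ToyLongSum` are illustrative stand-ins, not the classes this commit actually adds:

```scala
// Illustrative sketch only: models the single-class accumulator design from the
// commit message. It is not the code added by this commit.
case class ToyAccumulatorMetadata(id: Long, name: Option[String], countFailedValues: Boolean)

abstract class ToyAccumulator[IN, OUT] extends Serializable {
  // id/name/countFailedValues live in one metadata object, attached at registration time.
  private var metadata: ToyAccumulatorMetadata = _

  def register(meta: ToyAccumulatorMetadata): Unit = { metadata = meta }
  def isRegistered: Boolean = metadata != null
  def id: Long = {
    require(isRegistered, "accumulator must be registered before use")
    metadata.id
  }

  def add(v: IN): Unit                              // accumulate one input value
  def merge(other: ToyAccumulator[IN, OUT]): Unit   // combine partial results from tasks
  def value: OUT                                    // read the result on the driver
}

// The intermediate type and the result type may differ; here they are both Long.
class ToyLongSum extends ToyAccumulator[Long, Long] {
  private var sum = 0L                              // the zero value
  override def add(v: Long): Unit = sum += v
  override def merge(other: ToyAccumulator[Long, Long]): Unit = sum += other.value
  override def value: Long = sum
}
```

In the real change, registration also records the accumulator in a driver-side registry (the `AccumulatorContext.register` call visible in the `Accumulable.scala` diff below), so that updates coming back from tasks can be matched to the original instance by id.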


73 files changed: 1071 additions & 842 deletions

core/src/main/scala/org/apache/spark/Accumulable.scala

Lines changed: 16 additions & 48 deletions
@@ -63,7 +63,7 @@ class Accumulable[R, T] private (
       param: AccumulableParam[R, T],
       name: Option[String],
       countFailedValues: Boolean) = {
-    this(Accumulators.newId(), initialValue, param, name, countFailedValues)
+    this(AccumulatorContext.newId(), initialValue, param, name, countFailedValues)
   }
 
   private[spark] def this(initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = {
@@ -72,61 +72,44 @@ class Accumulable[R, T] private (
 
   def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None)
 
-  @volatile @transient private var value_ : R = initialValue // Current value on driver
-  val zero = param.zero(initialValue) // Zero value to be passed to executors
-  private var deserialized = false
-
-  Accumulators.register(this)
-
-  /**
-   * Return a copy of this [[Accumulable]].
-   *
-   * The copy will have the same ID as the original and will not be registered with
-   * [[Accumulators]] again. This method exists so that the caller can avoid passing the
-   * same mutable instance around.
-   */
-  private[spark] def copy(): Accumulable[R, T] = {
-    new Accumulable[R, T](id, initialValue, param, name, countFailedValues)
-  }
+  val zero = param.zero(initialValue)
+  private[spark] val newAcc = new LegacyAccumulatorWrapper(initialValue, param)
+  newAcc.metadata = AccumulatorMetadata(id, name, countFailedValues)
+  // Register the new accumulator in ctor, to follow the previous behaviour.
+  AccumulatorContext.register(newAcc)
 
   /**
    * Add more data to this accumulator / accumulable
   * @param term the data to add
   */
-  def += (term: T) { value_ = param.addAccumulator(value_, term) }
+  def += (term: T) { newAcc.add(term) }
 
   /**
    * Add more data to this accumulator / accumulable
   * @param term the data to add
   */
-  def add(term: T) { value_ = param.addAccumulator(value_, term) }
+  def add(term: T) { newAcc.add(term) }
 
   /**
    * Merge two accumulable objects together
   *
   * Normally, a user will not want to use this version, but will instead call `+=`.
   * @param term the other `R` that will get merged with this
   */
-  def ++= (term: R) { value_ = param.addInPlace(value_, term)}
+  def ++= (term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }
 
   /**
    * Merge two accumulable objects together
   *
   * Normally, a user will not want to use this version, but will instead call `add`.
   * @param term the other `R` that will get merged with this
   */
-  def merge(term: R) { value_ = param.addInPlace(value_, term)}
+  def merge(term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }
 
   /**
    * Access the accumulator's current value; only allowed on driver.
   */
-  def value: R = {
-    if (!deserialized) {
-      value_
-    } else {
-      throw new UnsupportedOperationException("Can't read accumulator value in task")
-    }
-  }
+  def value: R = newAcc.value
 
   /**
    * Get the current value of this accumulator from within a task.
@@ -137,14 +120,14 @@ class Accumulable[R, T] private (
   * The typical use of this method is to directly mutate the local value, eg., to add
   * an element to a Set.
   */
-  def localValue: R = value_
+  def localValue: R = newAcc.localValue
 
   /**
    * Set the accumulator's value; only allowed on driver.
   */
   def value_= (newValue: R) {
-    if (!deserialized) {
-      value_ = newValue
+    if (newAcc.isAtDriverSide) {
+      newAcc._value = newValue
     } else {
       throw new UnsupportedOperationException("Can't assign accumulator value in task")
     }
@@ -153,7 +136,7 @@ class Accumulable[R, T] private (
   /**
    * Set the accumulator's value. For internal use only.
   */
-  def setValue(newValue: R): Unit = { value_ = newValue }
+  def setValue(newValue: R): Unit = { newAcc._value = newValue }
 
   /**
    * Set the accumulator's value. For internal use only.
@@ -168,22 +151,7 @@ class Accumulable[R, T] private (
     new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
   }
 
-  // Called by Java when deserializing an object
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
-    in.defaultReadObject()
-    value_ = zero
-    deserialized = true
-
-    // Automatically register the accumulator when it is deserialized with the task closure.
-    // This is for external accumulators and internal ones that do not represent task level
-    // metrics, e.g. internal SQL metrics, which are per-operator.
-    val taskContext = TaskContext.get()
-    if (taskContext != null) {
-      taskContext.registerAccumulator(this)
-    }
-  }
-
-  override def toString: String = if (value_ == null) "null" else value_.toString
+  override def toString: String = if (newAcc._value == null) "null" else newAcc._value.toString
 }
 
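The diff above turns the legacy `Accumulable` into a thin facade: instead of mutating its own `value_` field through the old `AccumulableParam`, every operation now forwards to a wrapper built around that param. A rough, self-contained sketch of that adapter idea follows; `OldStyleParam`, `LegacyWrapper` and `OldFacade` are illustrative names, and the initialization details are assumptions rather than the exact behaviour of Spark's `LegacyAccumulatorWrapper`:

```scala
// Sketch of the adapter idea: an old-style param (zero/addAccumulator/addInPlace)
// is hidden behind a new-style add/merge/value interface. Illustrative only.
trait OldStyleParam[R, T] extends Serializable {
  def zero(initial: R): R
  def addAccumulator(acc: R, elem: T): R
  def addInPlace(r1: R, r2: R): R
}

class LegacyWrapper[R, T](initialValue: R, param: OldStyleParam[R, T]) extends Serializable {
  // Driver-side buffer starts at the initial value, as the old Accumulable did.
  private[this] var _value: R = initialValue

  def add(elem: T): Unit = { _value = param.addAccumulator(_value, elem) }
  def merge(other: LegacyWrapper[R, T]): Unit = { _value = param.addInPlace(_value, other.value) }
  def value: R = _value
}

// The old facade keeps its `+=` / `value` surface but simply forwards to the wrapper,
// mirroring how the rewritten Accumulable delegates to `newAcc`.
class OldFacade[R, T](initialValue: R, param: OldStyleParam[R, T]) {
  private val newAcc = new LegacyWrapper(initialValue, param)
  def += (term: T): Unit = newAcc.add(term)
  def value: R = newAcc.value
}
```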

core/src/main/scala/org/apache/spark/Accumulator.scala

Lines changed: 0 additions & 67 deletions
@@ -68,73 +68,6 @@ class Accumulator[T] private[spark] (
   extends Accumulable[T, T](initialValue, param, name, countFailedValues)
 
 
-// TODO: The multi-thread support in accumulators is kind of lame; check
-// if there's a more intuitive way of doing it right
-private[spark] object Accumulators extends Logging {
-  /**
-   * This global map holds the original accumulator objects that are created on the driver.
-   * It keeps weak references to these objects so that accumulators can be garbage-collected
-   * once the RDDs and user-code that reference them are cleaned up.
-   * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
-   */
-  @GuardedBy("Accumulators")
-  val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
-
-  private val nextId = new AtomicLong(0L)
-
-  /**
-   * Return a globally unique ID for a new [[Accumulable]].
-   * Note: Once you copy the [[Accumulable]] the ID is no longer unique.
-   */
-  def newId(): Long = nextId.getAndIncrement
-
-  /**
-   * Register an [[Accumulable]] created on the driver such that it can be used on the executors.
-   *
-   * All accumulators registered here can later be used as a container for accumulating partial
-   * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does.
-   * Note: if an accumulator is registered here, it should also be registered with the active
-   * context cleaner for cleanup so as to avoid memory leaks.
-   *
-   * If an [[Accumulable]] with the same ID was already registered, this does nothing instead
-   * of overwriting it. This happens when we copy accumulators, e.g. when we reconstruct
-   * [[org.apache.spark.executor.TaskMetrics]] from accumulator updates.
-   */
-  def register(a: Accumulable[_, _]): Unit = synchronized {
-    if (!originals.contains(a.id)) {
-      originals(a.id) = new WeakReference[Accumulable[_, _]](a)
-    }
-  }
-
-  /**
-   * Unregister the [[Accumulable]] with the given ID, if any.
-   */
-  def remove(accId: Long): Unit = synchronized {
-    originals.remove(accId)
-  }
-
-  /**
-   * Return the [[Accumulable]] registered with the given ID, if any.
-   */
-  def get(id: Long): Option[Accumulable[_, _]] = synchronized {
-    originals.get(id).map { weakRef =>
-      // Since we are storing weak references, we must check whether the underlying data is valid.
-      weakRef.get.getOrElse {
-        throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
-      }
-    }
-  }
-
-  /**
-   * Clear all registered [[Accumulable]]s. For testing only.
-   */
-  def clear(): Unit = synchronized {
-    originals.clear()
-  }
-
-}
-
-
 /**
  * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add
  * in is the same type as the accumulated value. An implicit AccumulatorParam object needs to be
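
The `Accumulators` object deleted above combined global id generation with a map of weak references keyed by id, so that accumulators referenced only by the registry can still be garbage collected. Per this commit the same responsibilities move to `AccumulatorContext` (see the `newId`, `register` and `remove` calls in the surrounding diffs). Below is a compact, self-contained sketch of that registry pattern; `WeakRegistry` is an illustrative name, not the new Spark object:

```scala
import java.lang.ref.WeakReference
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable

// Sketch of the id-keyed, weakly-referenced registry pattern used by the deleted
// Accumulators object. Illustrative only; not the AccumulatorContext implementation.
object WeakRegistry {
  private val nextId = new AtomicLong(0L)
  private val originals = mutable.Map[Long, WeakReference[AnyRef]]()

  def newId(): Long = nextId.getAndIncrement()

  // First registration wins; re-registering the same id is a no-op.
  def register(id: Long, a: AnyRef): Unit = synchronized {
    if (!originals.contains(id)) {
      originals(id) = new WeakReference[AnyRef](a)
    }
  }

  def remove(id: Long): Unit = synchronized { originals.remove(id) }

  // A weak reference may have been cleared by the GC, so a lookup can fail even
  // when the id is still present in the map.
  def get(id: Long): Option[AnyRef] = synchronized {
    originals.get(id).map { ref =>
      Option(ref.get).getOrElse(
        throw new IllegalStateException(s"Accumulator $id was garbage collected"))
    }
  }

  def clear(): Unit = synchronized { originals.clear() }
}
```

The `ContextCleaner` diff below hooks into exactly this shape: when an accumulator owner is cleaned up, the cleaner drops the stale registry entry by calling `remove` with the accumulator id.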

core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 2 additions & 2 deletions
@@ -144,7 +144,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     registerForCleanup(rdd, CleanRDD(rdd.id))
   }
 
-  def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+  def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = {
     registerForCleanup(a, CleanAccum(a.id))
   }
 
@@ -241,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   def doCleanupAccum(accId: Long, blocking: Boolean): Unit = {
     try {
       logDebug("Cleaning accumulator " + accId)
-      Accumulators.remove(accId)
+      AccumulatorContext.remove(accId)
       listeners.asScala.foreach(_.accumCleaned(accId))
       logInfo("Cleaned accumulator " + accId)
     } catch {

core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
  */
 private[spark] case class Heartbeat(
     executorId: String,
-    accumUpdates: Array[(Long, Seq[AccumulableInfo])], // taskId -> accum updates
+    accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], // taskId -> accumulator updates
     blockManagerId: BlockManagerId)
 
 /**

0 commit comments
