
Commit 5264164

dbkerkela authored and hvanhovell committed
[SPARK-24648][SQL] SqlMetrics should be threadsafe
Use LongAdder to make SQLMetrics thread safe.

## What changes were proposed in this pull request?

Replace `+=` with `LongAdder.add()` for concurrent counting.

## How was this patch tested?

Unit tests with local threads.

Author: Stacy Kerkela <stacy.kerkela@databricks.com>

Closes #21634 from dbkerkela/sqlmetrics-concurrency-stacy.
1 parent 594ac4f commit 5264164
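
For context, `java.util.concurrent.atomic.LongAdder` maintains internally striped cells that threads update without contending, and `sum()` folds them into a single total. The following is a minimal standalone sketch (illustrative only, not part of the patch; the object name and counts are made up) of the lost-update race that motivates the change:

```scala
import java.util.concurrent.atomic.LongAdder

object LongAdderVsVar {
  def main(args: Array[String]): Unit = {
    var plain = 0L            // unsynchronized: concurrent += can lose updates
    val adder = new LongAdder // striped cells: concurrent adds never collide

    val threads = (1 to 8).map { _ =>
      new Thread(new Runnable {
        override def run(): Unit = (1 to 100000).foreach { _ =>
          plain += 1    // racy read-modify-write on a shared var
          adder.add(1L) // thread-safe increment
        }
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())

    println(s"plain var: $plain")         // frequently less than 800000
    println(s"LongAdder: ${adder.sum()}") // always 800000
  }
}
```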

2 files changed: +55 -14

sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala

Lines changed: 20 additions & 13 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric
 
 import java.text.NumberFormat
 import java.util.Locale
+import java.util.concurrent.atomic.LongAdder
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.AccumulableInfo
@@ -32,40 +33,45 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
  * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]].
  */
 class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
+
   // This is a workaround for SPARK-11013.
   // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will
   // update it at the end of task and the value will be at least 0. Then we can filter out the -1
   // values before calculate max, min, etc.
-  private[this] var _value = initValue
-  private var _zeroValue = initValue
+  private[this] val _value = new LongAdder
+  private val _zeroValue = initValue
+  _value.add(initValue)
 
   override def copy(): SQLMetric = {
-    val newAcc = new SQLMetric(metricType, _value)
-    newAcc._zeroValue = initValue
+    val newAcc = new SQLMetric(metricType, initValue)
+    newAcc.add(_value.sum())
     newAcc
   }
 
-  override def reset(): Unit = _value = _zeroValue
+  override def reset(): Unit = this.set(_zeroValue)
 
   override def merge(other: AccumulatorV2[Long, Long]): Unit = other match {
-    case o: SQLMetric => _value += o.value
+    case o: SQLMetric => _value.add(o.value)
     case _ => throw new UnsupportedOperationException(
       s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
   }
 
-  override def isZero(): Boolean = _value == _zeroValue
+  override def isZero(): Boolean = _value.sum() == _zeroValue
 
-  override def add(v: Long): Unit = _value += v
+  override def add(v: Long): Unit = _value.add(v)
 
   // We can set a double value to `SQLMetric` which stores only long value, if it is
   // average metrics.
   def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v)
 
-  def set(v: Long): Unit = _value = v
+  def set(v: Long): Unit = {
+    _value.reset()
+    _value.add(v)
+  }
 
-  def +=(v: Long): Unit = _value += v
+  def +=(v: Long): Unit = _value.add(v)
 
-  override def value: Long = _value
+  override def value: Long = _value.sum()
 
   // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
   override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
@@ -153,7 +159,7 @@ object SQLMetrics {
       Seq.fill(3)(0L)
     } else {
       val sorted = validValues.sorted
-      Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+      Seq(sorted.head, sorted(validValues.length / 2), sorted(validValues.length - 1))
     }
     metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric))
   }
@@ -173,7 +179,8 @@ object SQLMetrics {
      Seq.fill(4)(0L)
    } else {
      val sorted = validValues.sorted
-      Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+      Seq(sorted.sum, sorted.head, sorted(validValues.length / 2),
+        sorted(validValues.length - 1))
    }
    metric.map(strFormat)
  }
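
For reference on the `LongAdder` calls introduced above: `add()` can be invoked from many threads without locking, `sum()` folds the internal cells into the current total, and `reset()` clears them, so the new `set(v: Long)` is a `reset()` followed by an `add(v)`, two steps rather than one atomic write. A small illustrative snippet (plain Scala, names not from Spark):

```scala
import java.util.concurrent.atomic.LongAdder

val counter = new LongAdder
counter.add(3L)             // safe to call concurrently from any thread
counter.add(4L)
assert(counter.sum() == 7L) // sum() folds the striped cells into one total

// "set" emulated as reset + add, mirroring the patched SQLMetric.set(v: Long);
// a concurrent reader calling sum() between the two calls could observe 0.
counter.reset()
counter.add(42L)
assert(counter.sum() == 42L)
```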

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 35 additions & 1 deletion
@@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.metric
 
 import java.io.File
 
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.execution.ui.SQLAppStatusStore
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -504,4 +504,38 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with SharedSQLContext {
   test("writing data out metrics with dynamic partition: parquet") {
     testMetricsDynamicPartition("parquet", "parquet", "t1")
   }
+
+  test("writing metrics from single thread") {
+    val nAdds = 10
+    val acc = new SQLMetric("test", -10)
+    assert(acc.isZero())
+    acc.set(0)
+    for (i <- 1 to nAdds) acc.add(1)
+    assert(!acc.isZero())
+    assert(nAdds === acc.value)
+    acc.reset()
+    assert(acc.isZero())
+  }
+
+  test("writing metrics from multiple threads") {
+    implicit val ec: ExecutionContextExecutor = ExecutionContext.global
+    val nFutures = 1000
+    val nAdds = 100
+    val acc = new SQLMetric("test", -10)
+    assert(acc.isZero() === true)
+    acc.set(0)
+    val l = for ( i <- 1 to nFutures ) yield {
+      Future {
+        for (j <- 1 to nAdds) acc.add(1)
+        i
+      }
+    }
+    for (futures <- Future.sequence(l)) {
+      assert(nFutures === futures.length)
+      assert(!acc.isZero())
+      assert(nFutures * nAdds === acc.value)
+      acc.reset()
+      assert(acc.isZero())
+    }
+  }
 }
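
One note on the multi-threaded test above: in Scala, `for (futures <- Future.sequence(l)) { ... }` desugars to `foreach`, which registers the assertion block as a callback and returns immediately rather than blocking the calling thread. A minimal sketch (standard library only, values illustrative and not from the patch) of how a caller can instead block until all futures have completed:

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

implicit val ec: ExecutionContext = ExecutionContext.global

// Illustrative workload: ten futures computing squares.
val futures = (1 to 10).map(i => Future { i * i })

// Await.result blocks until every future finishes (or the timeout
// expires), so assertions placed after it are guaranteed to run
// before the enclosing test method returns.
val results = Await.result(Future.sequence(futures), 10.seconds)
assert(results.length == 10)
```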
