Skip to content

Commit 4ff8912

Browse files
committed
handle -1 values
1 parent 778992e commit 4ff8912

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,19 @@ private[sql] trait SQLMetricValue[T] extends Serializable {
6767
private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {
6868

6969
def add(incr: Long): LongSQLMetricValue = {
70-
_value += incr
70+
// Some LongSQLMetrics use -1 as their initial value so that an accumulator which is never
71+
// updated can be filtered out later. Once `add` is called the accumulator is treated as valid,
72+
// so a negative value is first reset to 0. NOTE(review): the reset runs even when incr <= 0.
73+
if (_value < 0) {
74+
_value = 0
75+
}
76+
77+
// Updates merged on the driver side may still carry the -1 initial value, so only strictly
78+
// positive increments are added; zero and negative increments are ignored.
79+
if (incr > 0) {
80+
_value += incr
81+
}
82+
7183
this
7284
}
7385

sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
package org.apache.spark.sql.util
1919

20-
import org.apache.spark.SparkException
20+
import scala.collection.mutable.ArrayBuffer
21+
22+
import org.apache.spark._
2123
import org.apache.spark.sql.{functions, QueryTest}
2224
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
2325
import org.apache.spark.sql.execution.QueryExecution
2426
import org.apache.spark.sql.test.SharedSQLContext
2527

26-
import scala.collection.mutable.ArrayBuffer
27-
2828
class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
2929
import testImplicits._
3030
import functions._
@@ -81,7 +81,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
8181
assert(metrics(0)._3.getMessage == e.getMessage)
8282
}
8383

84-
test("get metrics by callback") {
84+
test("get numRows metrics by callback") {
8585
val metrics = ArrayBuffer.empty[Long]
8686
val listener = new QueryExecutionListener {
8787
// Only test successful case here, so no need to implement `onFailure`
@@ -103,4 +103,43 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
103103
assert(metrics(1) == 1)
104104
assert(metrics(2) == 2)
105105
}
106+
107+
test("get size metrics by callback") {
108+
val metrics = ArrayBuffer.empty[Long]
109+
val listener = new QueryExecutionListener {
110+
// Only the successful case is tested here, so `onFailure` is left as a no-op.
111+
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
112+
113+
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
114+
metrics += qe.executedPlan.longMetric("dataSize").value.value
115+
val bottomAgg = qe.executedPlan.children(0).children(0)
116+
metrics += bottomAgg.longMetric("dataSize").value.value
117+
}
118+
}
119+
sqlContext.listenerManager.register(listener)
120+
121+
val sparkListener = new SaveInfoListener
122+
sqlContext.sparkContext.addSparkListener(sparkListener)
123+
124+
val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j")
125+
df.groupBy("i").count().collect()
126+
127+
def getPeakExecutionMemory(stageId: Int): Long = {
128+
val peakMemoryAccumulator = sparkListener.getCompletedStageInfos(stageId).accumulables
129+
.filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
130+
131+
assert(peakMemoryAccumulator.size == 1)
132+
peakMemoryAccumulator.head._2.value.toLong
133+
}
134+
135+
assert(sparkListener.getCompletedStageInfos.length == 2)
136+
val bottomAggDataSize = getPeakExecutionMemory(0)
137+
val topAggDataSize = getPeakExecutionMemory(1)
138+
139+
// In this simple case each stage has only one memory-consuming operator, so the stage's
140+
// peakExecutionMemory should equal the "dataSize" metric of that stage's aggregate operator.
141+
assert(metrics.length == 2)
142+
assert(metrics(0) == topAggDataSize)
143+
assert(metrics(1) == bottomAggDataSize)
144+
}
106145
}

0 commit comments

Comments
 (0)