diff --git a/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto b/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
index 386c660b16def..85e1b827b9376 100644
--- a/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
+++ b/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
@@ -407,6 +407,7 @@ message SQLPlanMetric {
optional string name = 1;
int64 accumulator_id = 2;
optional string metric_type = 3;
+ int64 init_value = 4;
}
message SQLExecutionUIData {
diff --git a/core/src/main/scala/org/apache/spark/util/MetricUtils.scala b/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
index a6166f2129d1b..0bff7ff1acbee 100644
--- a/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
@@ -48,7 +48,8 @@ private[spark] object MetricUtils {
* A function that defines how we aggregate the final accumulator results among all tasks,
* and represent it in string for a SQL physical operator.
*/
- def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
+ def stringValue(metricsType: String, initValue: Long,
+ values: Array[Long], maxMetrics: Array[Long]): String = {
// taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
val taskInfo = if (maxMetrics.isEmpty) {
"(driver)"
@@ -59,7 +60,7 @@ private[spark] object MetricUtils {
val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
numberFormat.format(values.sum)
} else if (metricsType == AVERAGE_METRIC) {
- val validValues = values.filter(_ > 0)
+ val validValues = values.filter(_ > initValue)
// When there are only 1 metrics value (or None), no need to display max/min/median. This is
// common for driver-side SQL metrics.
if (validValues.length <= 1) {
@@ -85,7 +86,7 @@ private[spark] object MetricUtils {
throw SparkException.internalError(s"unexpected metrics type: $metricsType")
}
- val validValues = values.filter(_ >= 0)
+ val validValues = values.filter(_ > initValue)
// When there are only 1 metrics value (or None), no need to display max/min/median. This is
// common for driver-side SQL metrics.
if (validValues.length <= 1) {
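The two filter changes above are the behavioural core of the patch: instead of comparing against a hard-coded 0 (strictly greater for averages, greater-or-equal elsewhere), validity is now judged against the metric's own initial value. A minimal, self-contained sketch of that filter, with invented sample numbers:

    object InitValueFilterSketch extends App {
      // Mirrors the new `values.filter(_ > initValue)` lines; not the Spark code itself.
      def validValues(initValue: Long, values: Array[Long]): Array[Long] =
        values.filter(_ > initValue)

      // A timing-style metric initialised to -1: tasks that legitimately report 0
      // are strictly greater than initValue, so they stay in the min/med/max aggregation.
      println(validValues(-1L, Array(0L, 0L, 5L, 12L)).mkString(", "))  // 0, 0, 5, 12

      // A metric initialised to 0: untouched tasks still sitting at 0 are dropped
      // before the percentiles are computed.
      println(validValues(0L, Array(0L, 0L, 5L, 12L)).mkString(", "))   // 5, 12
    }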
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 615c8746a3e52..c7243309acd61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -77,7 +77,7 @@ private[execution] object SparkPlanInfo {
case _ => plan.children ++ plan.subqueries
}
val metrics = plan.metrics.toSeq.map { case (key, metric) =>
- new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
+ new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType, metric.initValue)
}
// dump the file scan metadata (e.g file path) to event log
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 07d215f8a186f..5c11dc1f7e1df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -847,7 +847,8 @@ case class AdaptiveSparkPlanExec(
private def onUpdatePlan(executionId: Long, newSubPlans: Seq[SparkPlan]): Unit = {
if (!shouldUpdatePlan) {
val newMetrics = newSubPlans.flatMap { p =>
- p.flatMap(_.metrics.values.map(m => SQLPlanMetric(m.name.get, m.id, m.metricType)))
+ p.flatMap(_.metrics.values.map(m =>
+ SQLPlanMetric(m.name.get, m.id, m.metricType, m.initValue)))
}
context.session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveSQLMetricUpdates(
executionId, newMetrics))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
index 2db2ff74374ca..a24067bf29545 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
@@ -29,7 +29,7 @@ class PythonCustomMetric(
def this() = this(null, null)
override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
- MetricUtils.stringValue("size", taskMetrics, Array.empty[Long])
+ MetricUtils.stringValue("size", 0L, taskMetrics, Array.empty[Long])
}
}
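PythonCustomMetric carries no initial value of its own, so the call site pins initValue to an explicit 0L. A hypothetical usage sketch, assuming the two constructor arguments are a metric name and a description (the numbers are invented):

    // Driver-side aggregation of three task updates for a Python data source metric.
    val metric = new PythonCustomMetric("rowsRead", "number of rows read")
    val rendered = metric.aggregateTaskMetrics(Array(1024L, 2048L, 0L))
    // With initValue = 0, the 0-valued entry is filtered out before the
    // human-readable size summary is built.
    println(rendered)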
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
index adb81519dbc83..6c011c8762b07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
@@ -27,4 +27,5 @@ import org.apache.spark.annotation.DeveloperApi
class SQLMetricInfo(
val name: String,
val accumulatorId: Long,
- val metricType: String)
+ val metricType: String,
+ val initValue: Long)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 065c8db7ac6f9..3b47eb87ce7b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -34,7 +34,7 @@ import org.apache.spark.util.AccumulatorContext.internOption
*/
class SQLMetric(
val metricType: String,
- initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
+ val initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
// initValue defines the initial value of the metric. 0 is the lowest value considered valid.
// If a SQLMetric is invalid, it is set to 0 upon receiving any updates, and it also reports
// 0 as its value to avoid exposing it to the user programmatically.
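Promoting initValue to a val is what lets the other consumers in this patch (SparkPlanInfo, the AQE update path, the UI listener, the protobuf serializer) read the value back off the metric. A small sketch of direct construction, assuming Spark is on the classpath; in practice the factory methods in SQLMetrics are the usual entry points:

    import org.apache.spark.sql.execution.metric.SQLMetric

    // Invented example: a nanosecond-timing metric that starts out at -1,
    // matching the nsTiming / initValue = -1 pair used in SparkPlanGraphSuite below.
    val timing = new SQLMetric("nsTiming", initValue = -1L)
    println(timing.metricType)  // nsTiming
    println(timing.initValue)   // -1, readable now that initValue is a val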
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index f680860231f01..b710c00bed375 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -235,7 +235,7 @@ class SQLAppStatusListener(
}
}.getOrElse(
// Built-in SQLMetric
- MetricUtils.stringValue(m.metricType, _, _)
+ MetricUtils.stringValue(m.metricType, m.initValue, _, _)
)
(m.accumulatorId, metricAggMethod)
}.toMap
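The listener builds one aggregation closure per accumulator; for built-in SQLMetrics that closure now captures the metric's initValue as well. The partially-applied call above eta-expands to roughly the following (a sketch only; MetricUtils is private[spark], so this would have to sit inside an org.apache.spark package):

    import org.apache.spark.sql.execution.ui.SQLPlanMetric
    import org.apache.spark.util.MetricUtils

    // Equivalent to `MetricUtils.stringValue(m.metricType, m.initValue, _, _)`.
    def aggMethodFor(m: SQLPlanMetric): (Array[Long], Array[Long]) => String =
      (values, maxMetrics) =>
        MetricUtils.stringValue(m.metricType, m.initValue, values, maxMetrics)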
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
index 8681bfb2342b4..6818e82be3015 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
@@ -171,4 +171,5 @@ class SparkPlanGraphNodeWrapper(
case class SQLPlanMetric(
name: String,
accumulatorId: Long,
- metricType: String)
+ metricType: String,
+ initValue: Long)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index ced4b6224c884..ec4b7554acaa7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -81,7 +81,7 @@ object SparkPlanGraph {
planInfo.nodeName match {
case name if name.startsWith("WholeStageCodegen") =>
val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+ SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.initValue)
}
val cluster = new SparkPlanGraphCluster(
@@ -127,7 +127,7 @@ object SparkPlanGraph {
edges += SparkPlanGraphEdge(node.id, parent.id)
case name =>
val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+ SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.initValue)
}
val node = new SparkPlanGraphNode(
nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
diff --git a/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/SQLPlanMetricSerializer.scala b/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/SQLPlanMetricSerializer.scala
index a0c15c3c322fd..73444607d850c 100644
--- a/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/SQLPlanMetricSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/SQLPlanMetricSerializer.scala
@@ -29,6 +29,7 @@ private[protobuf] object SQLPlanMetricSerializer {
setStringField(metric.name, builder.setName)
builder.setAccumulatorId(metric.accumulatorId)
setStringField(metric.metricType, builder.setMetricType)
+ builder.setInitValue(metric.initValue)
builder.build()
}
@@ -36,7 +37,8 @@ private[protobuf] object SQLPlanMetricSerializer {
SQLPlanMetric(
name = getStringField(metrics.hasName, () => weakIntern(metrics.getName)),
accumulatorId = metrics.getAccumulatorId,
- metricType = getStringField(metrics.hasMetricType, () => weakIntern(metrics.getMetricType))
+ metricType = getStringField(metrics.hasMetricType, () => weakIntern(metrics.getMetricType)),
+ initValue = metrics.getInitValue
)
}
}
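These serializer changes line up with the new int64 init_value = 4 proto field, so a metric's initial value should survive a KVStore round trip. A sketch of the intended behaviour, assuming serialize/deserialize map between SQLPlanMetric and its StoreTypes counterpart as the hunk suggests (the object is private[protobuf]; the metric itself is invented, and the real coverage is in KVStoreProtobufSerializerSuite below):

    import org.apache.spark.sql.execution.ui.SQLPlanMetric

    val before = SQLPlanMetric("JDBC query execution time", 35L, "nsTiming", initValue = -1L)
    val after = SQLPlanMetricSerializer.deserialize(SQLPlanMetricSerializer.serialize(before))
    assert(after.initValue == before.initValue)  // -1 survives the protobuf round trip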
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala
index b194ce8f84887..2e0765579efd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala
@@ -61,7 +61,7 @@ object MetricsAggregationBenchmark extends BenchmarkBase {
val store = new SQLAppStatusStore(kvstore, Some(listener))
val metrics = (0 until numMetrics).map { i =>
- new SQLMetricInfo(s"metric$i", i.toLong, "average")
+ new SQLMetricInfo(s"metric$i", i.toLong, "average", 0L)
}
val planInfo = new SparkPlanInfo(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index 800a58f0c1d63..8e369b7c14849 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -597,9 +597,9 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
val metrics = statusStore.executionMetrics(execId)
val driverMetric = physicalPlan.metrics("dummy")
val driverMetric2 = physicalPlan.metrics("dummy2")
- val expectedValue = MetricUtils.stringValue(driverMetric.metricType,
+ val expectedValue = MetricUtils.stringValue(driverMetric.metricType, driverMetric.initValue,
Array(expectedAccumValue), Array.empty[Long])
- val expectedValue2 = MetricUtils.stringValue(driverMetric2.metricType,
+ val expectedValue2 = MetricUtils.stringValue(driverMetric2.metricType, driverMetric2.initValue,
Array(expectedAccumValue2), Array.empty[Long])
assert(metrics.contains(driverMetric.id))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala
index 975dbc1a1d8df..e7befd5c46ab4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala
@@ -29,12 +29,14 @@ class SparkPlanGraphSuite extends SparkFunSuite {
SQLPlanMetric(
name = "number of output rows",
accumulatorId = 75,
- metricType = "sum"
+ metricType = "sum",
+ initValue = 0L
),
SQLPlanMetric(
name = "JDBC query execution time",
accumulatorId = 35,
- metricType = "nsTiming")))
+ metricType = "nsTiming",
+ initValue = -1L)))
val dotNode = planGraphNode.makeDotNode(Map.empty[Long, String])
val expectedDotNode = " 24 [id=\"node24\" labelType=\"html\" label=\"" +
"
Scan JDBCRelation(\\\"test-schema\\\".tickets) [numPartitions=1]
\" " +
diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala
index c5e2e657de8cb..6e2ff48f7aa37 100644
--- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala
@@ -42,17 +42,17 @@ object SqlResourceSuite {
val nodeIdAndWSCGIdMap: Map[Long, Option[Long]] = Map(1L -> Some(1L))
val filterNode = new SparkPlanGraphNode(1, FILTER, "",
- metrics = Seq(SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 1, "")))
+ metrics = Seq(SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 1, "", 0L)))
val nodes: Seq[SparkPlanGraphNode] = Seq(
new SparkPlanGraphCluster(0, WHOLE_STAGE_CODEGEN_1, "",
nodes = ArrayBuffer(filterNode),
- metrics = Seq(SQLPlanMetric(DURATION, 0, ""))),
+ metrics = Seq(SQLPlanMetric(DURATION, 0, "", 0L))),
new SparkPlanGraphNode(2, SCAN_TEXT, "",
metrics = Seq(
- SQLPlanMetric(METADATA_TIME, 2, ""),
- SQLPlanMetric(NUMBER_OF_FILES_READ, 3, ""),
- SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 4, ""),
- SQLPlanMetric(SIZE_OF_FILES_READ, 5, ""))))
+ SQLPlanMetric(METADATA_TIME, 2, "", 0L),
+ SQLPlanMetric(NUMBER_OF_FILES_READ, 3, "", 0L),
+ SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 4, "", 0L),
+ SQLPlanMetric(SIZE_OF_FILES_READ, 5, "", 0L))))
val edges: Seq[SparkPlanGraphEdge] = Seq(SparkPlanGraphEdge(3, 2))
@@ -60,12 +60,12 @@ object SqlResourceSuite {
SparkPlanGraph(nodes, edges).allNodes.filterNot(_.name == WHOLE_STAGE_CODEGEN_1)
val metrics: Seq[SQLPlanMetric] = {
- Seq(SQLPlanMetric(DURATION, 0, ""),
- SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 1, ""),
- SQLPlanMetric(METADATA_TIME, 2, ""),
- SQLPlanMetric(NUMBER_OF_FILES_READ, 3, ""),
- SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 4, ""),
- SQLPlanMetric(SIZE_OF_FILES_READ, 5, ""))
+ Seq(SQLPlanMetric(DURATION, 0, "", 0L),
+ SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 1, "", 0L),
+ SQLPlanMetric(METADATA_TIME, 2, "", 0L),
+ SQLPlanMetric(NUMBER_OF_FILES_READ, 3, "", 0L),
+ SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 4, "", 0L),
+ SQLPlanMetric(SIZE_OF_FILES_READ, 5, "", 0L))
}
private def getMetricValues() = {
diff --git a/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala
index 3f3a6925409cd..2cbad203e0e6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala
@@ -45,7 +45,7 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
details = null,
physicalPlanDescription = null,
modifiedConfigs = normal.modifiedConfigs,
- metrics = Seq(SQLPlanMetric(null, 0, null)),
+ metrics = Seq(SQLPlanMetric(null, 0, null, 0L)),
submissionTime = normal.submissionTime,
completionTime = normal.completionTime,
errorMessage = normal.errorMessage,
@@ -126,12 +126,14 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
SQLPlanMetric(
name = "name_13",
accumulatorId = 13,
- metricType = "metric_13"
+ metricType = "metric_13",
+ initValue = 0L
),
SQLPlanMetric(
name = "name_14",
accumulatorId = 14,
- metricType = "metric_14"
+ metricType = "metric_14",
+ initValue = 0L
)
)
),
@@ -147,7 +149,8 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
SQLPlanMetric(
name = null,
accumulatorId = 13,
- metricType = null
+ metricType = null,
+ initValue = 0L
)
)
),
@@ -174,12 +177,14 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
SQLPlanMetric(
name = "name_6",
accumulatorId = 6,
- metricType = "metric_6"
+ metricType = "metric_6",
+ initValue = 0L
),
SQLPlanMetric(
name = "name_7 d",
accumulatorId = 7,
- metricType = "metric_7"
+ metricType = "metric_7",
+ initValue = 0L
)
)
)