diff --git a/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto b/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
index 386c660b16def..85e1b827b9376 100644
--- a/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
+++ b/core/src/main/protobuf/org/apache/spark/status/protobuf/store_types.proto
@@ -407,6 +407,7 @@ message SQLPlanMetric {
   optional string name = 1;
   int64 accumulator_id = 2;
   optional string metric_type = 3;
+  int64 init_value = 4;
 }
 
 message SQLExecutionUIData {
diff --git a/core/src/main/scala/org/apache/spark/util/MetricUtils.scala b/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
index a6166f2129d1b..0bff7ff1acbee 100644
--- a/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
@@ -48,7 +48,8 @@ private[spark] object MetricUtils {
    * A function that defines how we aggregate the final accumulator results among all tasks,
    * and represent it in string for a SQL physical operator.
    */
-  def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
+  def stringValue(metricsType: String, initValue: Long,
+      values: Array[Long], maxMetrics: Array[Long]): String = {
     // taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
     val taskInfo = if (maxMetrics.isEmpty) {
       "(driver)"
@@ -59,7 +60,7 @@ private[spark] object MetricUtils {
       val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
       numberFormat.format(values.sum)
     } else if (metricsType == AVERAGE_METRIC) {
-      val validValues = values.filter(_ > 0)
+      val validValues = values.filter(_ > initValue)
       // When there are only 1 metrics value (or None), no need to display max/min/median. This is
       // common for driver-side SQL metrics.
       if (validValues.length <= 1) {
@@ -85,7 +86,7 @@ private[spark] object MetricUtils {
        throw SparkException.internalError(s"unexpected metrics type: $metricsType")
      }
 
-      val validValues = values.filter(_ >= 0)
+      val validValues = values.filter(_ > initValue)
       // When there are only 1 metrics value (or None), no need to display max/min/median. This is
       // common for driver-side SQL metrics.
       if (validValues.length <= 1) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 615c8746a3e52..c7243309acd61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -77,7 +77,7 @@ private[execution] object SparkPlanInfo {
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
-      new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
+      new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType, metric.initValue)
     }
 
     // dump the file scan metadata (e.g file path) to event log
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 07d215f8a186f..5c11dc1f7e1df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -847,7 +847,8 @@ case class AdaptiveSparkPlanExec(
   private def onUpdatePlan(executionId: Long, newSubPlans: Seq[SparkPlan]): Unit = {
     if (!shouldUpdatePlan) {
       val newMetrics = newSubPlans.flatMap { p =>
-        p.flatMap(_.metrics.values.map(m => SQLPlanMetric(m.name.get, m.id, m.metricType)))
+        p.flatMap(_.metrics.values.map(m =>
+          SQLPlanMetric(m.name.get, m.id, m.metricType, m.initValue)))
       }
       context.session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveSQLMetricUpdates(
         executionId, newMetrics))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
index 2db2ff74374ca..a24067bf29545 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
@@ -29,7 +29,7 @@ class PythonCustomMetric(
   def this() = this(null, null)
 
   override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
-    MetricUtils.stringValue("size", taskMetrics, Array.empty[Long])
+    MetricUtils.stringValue("size", 0L, taskMetrics, Array.empty[Long])
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
index adb81519dbc83..6c011c8762b07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
@@ -27,4 +27,5 @@ import org.apache.spark.annotation.DeveloperApi
 class SQLMetricInfo(
     val name: String,
     val accumulatorId: Long,
-    val metricType: String)
+    val metricType: String,
+    val initValue: Long)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 065c8db7ac6f9..3b47eb87ce7b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -34,7 +34,7 @@ import org.apache.spark.util.AccumulatorContext.internOption
  */
 class SQLMetric(
     val metricType: String,
-    initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
+    val initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
   // initValue defines the initial value of the metric. 0 is the lowest value considered valid.
   // If a SQLMetric is invalid, it is set to 0 upon receiving any updates, and it also reports
   // 0 as its value to avoid exposing it to the user programmatically.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index f680860231f01..b710c00bed375 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -235,7 +235,7 @@ class SQLAppStatusListener(
         }
       }.getOrElse(
         // Built-in SQLMetric
-        MetricUtils.stringValue(m.metricType, _, _)
+        MetricUtils.stringValue(m.metricType, m.initValue, _, _)
       )
       (m.accumulatorId, metricAggMethod)
     }.toMap
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
index 8681bfb2342b4..6818e82be3015 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
@@ -171,4 +171,5 @@ class SparkPlanGraphNodeWrapper(
 case class SQLPlanMetric(
     name: String,
     accumulatorId: Long,
-    metricType: String)
+    metricType: String,
+    initValue: Long)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index ced4b6224c884..ec4b7554acaa7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -81,7 +81,7 @@ object SparkPlanGraph {
     planInfo.nodeName match {
       case name if name.startsWith("WholeStageCodegen") =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.initValue)
         }
 
         val cluster = new SparkPlanGraphCluster(
@@ -127,7 +127,7 @@ object SparkPlanGraph {
         edges += SparkPlanGraphEdge(node.id, parent.id)
       case name =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.initValue)
         }
         val node = new SparkPlanGraphNode(
           nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
diff --git a/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/SQLPlanMetricSerializer.scala b/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/SQLPlanMetricSerializer.scala
index a0c15c3c322fd..73444607d850c 100644
--- a/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/SQLPlanMetricSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/status/protobuf/sql/SQLPlanMetricSerializer.scala
@@ -29,6 +29,7 @@ private[protobuf] object SQLPlanMetricSerializer {
     setStringField(metric.name, builder.setName)
     builder.setAccumulatorId(metric.accumulatorId)
     setStringField(metric.metricType, builder.setMetricType)
+    builder.setInitValue(metric.initValue)
     builder.build()
   }
 
@@ -36,7 +37,8 @@ private[protobuf] object SQLPlanMetricSerializer {
     SQLPlanMetric(
       name = getStringField(metrics.hasName, () => weakIntern(metrics.getName)),
       accumulatorId = metrics.getAccumulatorId,
-      metricType = getStringField(metrics.hasMetricType, () => weakIntern(metrics.getMetricType))
+      metricType = getStringField(metrics.hasMetricType, () => weakIntern(metrics.getMetricType)),
+      initValue = metrics.getInitValue
     )
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala
index b194ce8f84887..2e0765579efd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala
@@ -61,7 +61,7 @@ object MetricsAggregationBenchmark extends BenchmarkBase {
     val store = new SQLAppStatusStore(kvstore, Some(listener))
 
     val metrics = (0 until numMetrics).map { i =>
-      new SQLMetricInfo(s"metric$i", i.toLong, "average")
+      new SQLMetricInfo(s"metric$i", i.toLong, "average", 0L)
     }
 
     val planInfo = new SparkPlanInfo(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index 800a58f0c1d63..8e369b7c14849 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -597,9 +597,9 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
     val metrics = statusStore.executionMetrics(execId)
     val driverMetric = physicalPlan.metrics("dummy")
     val driverMetric2 = physicalPlan.metrics("dummy2")
-    val expectedValue = MetricUtils.stringValue(driverMetric.metricType,
+    val expectedValue = MetricUtils.stringValue(driverMetric.metricType, driverMetric.initValue,
       Array(expectedAccumValue), Array.empty[Long])
-    val expectedValue2 = MetricUtils.stringValue(driverMetric2.metricType,
+    val expectedValue2 = MetricUtils.stringValue(driverMetric2.metricType, driverMetric2.initValue,
       Array(expectedAccumValue2), Array.empty[Long])
 
     assert(metrics.contains(driverMetric.id))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala
index 975dbc1a1d8df..e7befd5c46ab4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SparkPlanGraphSuite.scala
@@ -29,12 +29,14 @@ class SparkPlanGraphSuite extends SparkFunSuite {
       SQLPlanMetric(
         name = "number of output rows",
         accumulatorId = 75,
-        metricType = "sum"
+        metricType = "sum",
+        initValue = 0L
       ),
       SQLPlanMetric(
         name = "JDBC query execution time",
         accumulatorId = 35,
-        metricType = "nsTiming")))
+        metricType = "nsTiming",
+        initValue = -1L)))
     val dotNode = planGraphNode.makeDotNode(Map.empty[Long, String])
     val expectedDotNode = "  24 [id=\"node24\" labelType=\"html\" label=\"" +
       "Scan JDBCRelation(\\\"test-schema\\\".tickets) [numPartitions=1]\" " +
diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala
index c5e2e657de8cb..6e2ff48f7aa37 100644
--- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala
@@ -42,17 +42,17 @@ object SqlResourceSuite {
   val nodeIdAndWSCGIdMap: Map[Long, Option[Long]] = Map(1L -> Some(1L))
 
   val filterNode = new SparkPlanGraphNode(1, FILTER, "",
-    metrics = Seq(SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 1, "")))
+    metrics = Seq(SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 1, "", 0L)))
   val nodes: Seq[SparkPlanGraphNode] = Seq(
     new SparkPlanGraphCluster(0, WHOLE_STAGE_CODEGEN_1, "",
       nodes = ArrayBuffer(filterNode),
-      metrics = Seq(SQLPlanMetric(DURATION, 0, ""))),
+      metrics = Seq(SQLPlanMetric(DURATION, 0, "", 0L))),
     new SparkPlanGraphNode(2, SCAN_TEXT, "",
       metrics = Seq(
-        SQLPlanMetric(METADATA_TIME, 2, ""),
-        SQLPlanMetric(NUMBER_OF_FILES_READ, 3, ""),
-        SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 4, ""),
-        SQLPlanMetric(SIZE_OF_FILES_READ, 5, ""))))
+        SQLPlanMetric(METADATA_TIME, 2, "", 0L),
+        SQLPlanMetric(NUMBER_OF_FILES_READ, 3, "", 0L),
+        SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 4, "", 0L),
+        SQLPlanMetric(SIZE_OF_FILES_READ, 5, "", 0L))))
 
   val edges: Seq[SparkPlanGraphEdge] = Seq(SparkPlanGraphEdge(3, 2))
 
@@ -60,12 +60,12 @@ object SqlResourceSuite {
     SparkPlanGraph(nodes, edges).allNodes.filterNot(_.name == WHOLE_STAGE_CODEGEN_1)
 
   val metrics: Seq[SQLPlanMetric] = {
-    Seq(SQLPlanMetric(DURATION, 0, ""),
-      SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 1, ""),
-      SQLPlanMetric(METADATA_TIME, 2, ""),
-      SQLPlanMetric(NUMBER_OF_FILES_READ, 3, ""),
-      SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 4, ""),
-      SQLPlanMetric(SIZE_OF_FILES_READ, 5, ""))
+    Seq(SQLPlanMetric(DURATION, 0, "", 0L),
+      SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 1, "", 0L),
+      SQLPlanMetric(METADATA_TIME, 2, "", 0L),
+      SQLPlanMetric(NUMBER_OF_FILES_READ, 3, "", 0L),
+      SQLPlanMetric(NUMBER_OF_OUTPUT_ROWS, 4, "", 0L),
+      SQLPlanMetric(SIZE_OF_FILES_READ, 5, "", 0L))
   }
 
   private def getMetricValues() = {
diff --git a/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala
index 3f3a6925409cd..2cbad203e0e6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/status/protobuf/sql/KVStoreProtobufSerializerSuite.scala
@@ -45,7 +45,7 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
       details = null,
       physicalPlanDescription = null,
       modifiedConfigs = normal.modifiedConfigs,
-      metrics = Seq(SQLPlanMetric(null, 0, null)),
+      metrics = Seq(SQLPlanMetric(null, 0, null, 0L)),
       submissionTime = normal.submissionTime,
       completionTime = normal.completionTime,
       errorMessage = normal.errorMessage,
@@ -126,12 +126,14 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
         SQLPlanMetric(
           name = "name_13",
           accumulatorId = 13,
-          metricType = "metric_13"
+          metricType = "metric_13",
+          initValue = 0L
         ),
         SQLPlanMetric(
           name = "name_14",
           accumulatorId = 14,
-          metricType = "metric_14"
+          metricType = "metric_14",
+          initValue = 0L
         )
       )
     ),
@@ -147,7 +149,8 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
         SQLPlanMetric(
           name = null,
           accumulatorId = 13,
-          metricType = null
+          metricType = null,
+          initValue = 0L
         )
       )
     ),
@@ -174,12 +177,14 @@ class KVStoreProtobufSerializerSuite extends SparkFunSuite {
         SQLPlanMetric(
           name = "name_6",
           accumulatorId = 6,
-          metricType = "metric_6"
+          metricType = "metric_6",
+          initValue = 0L
        ),
         SQLPlanMetric(
           name = "name_7 d",
           accumulatorId = 7,
-          metricType = "metric_7"
+          metricType = "metric_7",
+          initValue = 0L
         )
       )
     )
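
A note on the behavioral core of the patch: MetricUtils.stringValue now filters task values with "_ > initValue" instead of the hard-coded "_ > 0" / "_ >= 0", so the validity threshold comes from the metric itself (for example, the nsTiming metric in SparkPlanGraphSuite is given initValue = -1L). Below is a minimal, self-contained Scala sketch of that filtering idea; it is illustrative only, not Spark code, and the object name and sample values are made up.

    object InitValueFilterSketch {
      // Keep only values strictly above the metric's initial value, i.e. values
      // from tasks that actually updated the metric.
      def validValues(initValue: Long, values: Array[Long]): Array[Long] =
        values.filter(_ > initValue)

      def main(args: Array[String]): Unit = {
        val taskValues = Array(-1L, 0L, 42L) // one value still at the -1 init, two real measurements
        println(validValues(-1L, taskValues).mkString(", ")) // prints: 0, 42
        println(validValues(0L, taskValues).mkString(", "))  // prints: 42
      }
    }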