From ee4dd64bdccc33185e25e44cb0c80d44790c156b Mon Sep 17 00:00:00 2001
From: zouxxyy
Date: Mon, 12 Aug 2024 10:23:52 +0800
Subject: [PATCH 1/3] [SPARK-49210][SQL] Support driver metrics for DS v2 write

---
 .../metric/SupportCustomMetrics.java          | 40 ++++++++++++++++
 .../apache/spark/sql/connector/read/Scan.java | 24 +---------
 .../spark/sql/connector/write/Write.java      | 12 +----
 .../v2/DataSourceV2ScanExecBase.scala         | 22 +++------
 .../datasources/v2/V1FallbackWriters.scala    |  7 +++
 .../v2/WriteToDataSourceV2Exec.scala          |  7 ++-
 .../sql/execution/metric/SQLMetrics.scala     | 23 ++++++++-
 .../sql/connector/V1WriteFallbackSuite.scala  | 44 +++++++++++++++++
 .../ui/SQLAppStatusListenerSuite.scala        | 48 ++++++++++++++++++-
 9 files changed, 173 insertions(+), 54 deletions(-)
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/SupportCustomMetrics.java

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/SupportCustomMetrics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/SupportCustomMetrics.java
new file mode 100644
index 0000000000000..7fa5179641916
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/SupportCustomMetrics.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.metric;
+
+public interface SupportCustomMetrics {
+
+  /**
+   * Returns an array of supported custom metrics with name and description.
+   * By default, it returns an empty array.
+   */
+  default CustomMetric[] supportedCustomMetrics() {
+    return new CustomMetric[]{};
+  }
+
+  /**
+   * Returns an array of custom metrics which are collected with values at the driver side only.
+   * Note that these metrics must be included in the supported custom metrics reported by
+   * `supportedCustomMetrics`.
+   *
+   * @since 3.4.0
+   */
+  default CustomTaskMetric[] reportDriverMetrics() {
+    return new CustomTaskMetric[]{};
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
index 81b89e5750d83..59d5112a5e251 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java
@@ -21,8 +21,7 @@
 
 import org.apache.spark.SparkUnsupportedOperationException;
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.metric.CustomMetric;
-import org.apache.spark.sql.connector.metric.CustomTaskMetric;
+import org.apache.spark.sql.connector.metric.SupportCustomMetrics;
 import org.apache.spark.sql.connector.read.streaming.ContinuousStream;
 import org.apache.spark.sql.connector.read.streaming.MicroBatchStream;
 import org.apache.spark.sql.types.StructType;
@@ -43,7 +42,7 @@
  * @since 3.0.0
  */
 @Evolving
-public interface Scan {
+public interface Scan extends SupportCustomMetrics {
 
   /**
    * Returns the actual schema of this data source scan, which may be different from the physical
@@ -115,25 +114,6 @@ default ContinuousStream toContinuousStream(String checkpointLocation) {
       "_LEGACY_ERROR_TEMP_3149", Map.of("description", description()));
   }
 
-  /**
-   * Returns an array of supported custom metrics with name and description.
-   * By default it returns empty array.
-   */
-  default CustomMetric[] supportedCustomMetrics() {
-    return new CustomMetric[]{};
-  }
-
-  /**
-   * Returns an array of custom metrics which are collected with values at the driver side only.
-   * Note that these metrics must be included in the supported custom metrics reported by
-   * `supportedCustomMetrics`.
-   *
-   * @since 3.4.0
-   */
-  default CustomTaskMetric[] reportDriverMetrics() {
-    return new CustomTaskMetric[]{};
-  }
-
   /**
    * This enum defines how the columnar support for the partitions of the data source
    * should be determined. The default value is `PARTITION_DEFINED` which indicates that each
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java
index dc5712e93f470..9702d40e065ae 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java
@@ -23,7 +23,7 @@
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.catalog.Table;
 import org.apache.spark.sql.connector.catalog.TableCapability;
-import org.apache.spark.sql.connector.metric.CustomMetric;
+import org.apache.spark.sql.connector.metric.SupportCustomMetrics;
 import org.apache.spark.sql.connector.write.streaming.StreamingWrite;
 
 /**
@@ -38,7 +38,7 @@
  * @since 3.2.0
  */
 @Evolving
-public interface Write {
+public interface Write extends SupportCustomMetrics {
 
   /**
    * Returns the description associated with this write.
@@ -68,12 +68,4 @@ default StreamingWrite toStreaming() {
     throw new SparkUnsupportedOperationException(
       "_LEGACY_ERROR_TEMP_3138", Map.of("description", description()));
   }
-
-  /**
-   * Returns an array of supported custom metrics with name and description.
-   * By default it returns empty array.
- */ - default CustomMetric[] supportedCustomMetrics() { - return new CustomMetric[]{}; - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index 95d85dab5cedc..b29643b5cfdb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -24,21 +24,19 @@ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan} -import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SQLExecution} -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsMetadata import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils trait DataSourceV2ScanExecBase extends LeafExecNode { - lazy val customMetrics = scan.supportedCustomMetrics().map { customMetric => - customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric) - }.toMap + lazy val customMetrics: Map[String, SQLMetric] = + SQLMetrics.createV2CustomMetrics(sparkContext, scan) - override lazy val metrics = { + override lazy val metrics: Map[String, SQLMetric] = { Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) ++ customMetrics } @@ -191,15 +189,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } protected def postDriverMetrics(): Unit = { - val driveSQLMetrics = scan.reportDriverMetrics().map(customTaskMetric => { - val metric = metrics(customTaskMetric.name()) - metric.set(customTaskMetric.value()) - metric - }) - - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, - driveSQLMetrics.toImmutableArraySeq) + SQLMetrics.postV2DriverMetrics(sparkContext, scan, metrics) } override def doExecuteColumnar(): RDD[ColumnarBatch] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala index 6f83b82785955..9726620f389e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.SupportsWrite import org.apache.spark.sql.connector.write.V1Write import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.sources.InsertableRelation /** @@ -62,9 +63,15 @@ sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write { def refreshCache: () => Unit def write: V1Write + protected val 
customMetrics: Map[String, SQLMetric] = + SQLMetrics.createV2CustomMetrics(sparkContext, write) + + override lazy val metrics: Map[String, SQLMetric] = customMetrics + override def run(): Seq[InternalRow] = { val writtenRows = writeWithV1(write.toInsertableRelation) refreshCache() + SQLMetrics.postV2DriverMetrics(sparkContext, write, metrics) writtenRows } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 5632595de7cf8..5e7ff1539fe59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -330,13 +330,12 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { def write: Write override val customMetrics: Map[String, SQLMetric] = - write.supportedCustomMetrics().map { customMetric => - customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric) - }.toMap + SQLMetrics.createV2CustomMetrics(sparkContext, write) override protected def run(): Seq[InternalRow] = { val writtenRows = writeWithV2(write.toBatch) refreshCache() + SQLMetrics.postV2DriverMetrics(sparkContext, write, metrics) writtenRows } } @@ -355,7 +354,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { protected val customMetrics: Map[String, SQLMetric] = Map.empty - override lazy val metrics = customMetrics + override lazy val metrics: Map[String, SQLMetric] = customMetrics protected def writeWithV2(batchWrite: BatchWrite): Seq[InternalRow] = { val rdd: RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index a246b47fe655a..9b5bf63553422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -26,11 +26,13 @@ import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.scheduler.AccumulableInfo -import org.apache.spark.sql.connector.metric.CustomMetric +import org.apache.spark.sql.connector.metric.{CustomMetric, SupportCustomMetrics} import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} import org.apache.spark.util.AccumulatorContext.internOption +import org.apache.spark.util.ArrayImplicits._ /** * A metric used in a SQL query plan. This is implemented as an [[AccumulatorV2]]. 
Updates on @@ -152,6 +154,25 @@ object SQLMetrics { acc } + def createV2CustomMetrics( + sc: SparkContext, supportCustomMetrics: SupportCustomMetrics): Map[String, SQLMetric] = { + supportCustomMetrics.supportedCustomMetrics().map { customMetric => + customMetric.name() -> SQLMetrics.createV2CustomMetric(sc, customMetric) + }.toMap + } + + def postV2DriverMetrics(sc: SparkContext, supportCustomMetrics: SupportCustomMetrics, + metrics: Map[String, SQLMetric]): Unit = { + val driveSQLMetrics = supportCustomMetrics.reportDriverMetrics().map(customTaskMetric => { + val metric = metrics(customTaskMetric.name()) + metric.set(customTaskMetric.value()) + metric + }) + val executionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sc, executionId, + driveSQLMetrics.toImmutableArraySeq) + } + /** * Create a metric to report the size information (including total, min, med, max) like data size, * spill size, etc. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index ad31cf84eeb3f..278ec3b63c180 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, SupportsRead, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder} import org.apache.spark.sql.execution.datasources.DataSourceUtils @@ -198,6 +199,27 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before SparkSession.setDefaultSession(spark) } } + + test("SPARK-49210: report driver metrics from fallback write") { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + try { + val session = SparkSession.builder() + .master("local[1]") + .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName) + .getOrCreate() + val df = session.createDataFrame(Seq((1, "p1"), (2, "p2"), (3, "p2"))) + df.write.partitionBy("_2").mode("append").format(v2Format).saveAsTable("test") + val statusStore = session.sharedState.statusStore + val execId = statusStore.executionsList() + .find(x => x.physicalPlanDescription.contains("AppendData")).get.executionId + val metrics = statusStore.executionMetrics(execId) + assert(metrics.head._2 == "2") + } finally { + SparkSession.setActiveSession(spark) + SparkSession.setDefaultSession(spark) + } + } } class V1WriteFallbackSessionCatalogSuite @@ -376,6 +398,8 @@ class InMemoryTableWithV1Fallback( } override def build(): V1Write = new V1Write { + var writtenPartitionsCount: Long = 0L + override def toInsertableRelation: InsertableRelation = { (data: DataFrame, overwrite: Boolean) => { assert(!overwrite, "V1 write fallbacks cannot be called with overwrite=true") @@ -389,8 +413,17 @@ class InMemoryTableWithV1Fallback( dataMap.put(partition, elements.toImmutableArraySeq) } } + writtenPartitionsCount 
= dataMap.size } } + + override def supportedCustomMetrics(): Array[CustomMetric] = { + Array(new WrittenPartitionDriverMetric) + } + + override def reportDriverMetrics(): Array[CustomTaskMetric] = { + Array(new WrittenPartitionDriverTaskMetric(writtenPartitionsCount)) + } } } @@ -452,3 +485,14 @@ object OnlyOnceOptimizerRule extends Rule[LogicalPlan] { } } } + +class WrittenPartitionDriverMetric extends CustomMetric { + override def name(): String = "number_of_written_partitions" + override def description(): String = "Simple custom driver metrics: number of written partitions" + override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = taskMetrics.sum.toString +} + +class WrittenPartitionDriverTaskMetric(value : Long) extends CustomTaskMetric { + override def name(): String = "number_of_written_partitions" + override def value(): Long = value +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index e63ff019a2b6c..1a874d767336b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -945,6 +945,34 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes } } + test("SPARK-49210: Report driver metrics from Datasource v2 write") { + withTempDir { dir => + val statusStore = spark.sharedState.statusStore + val oldCount = statusStore.executionsList().size + + val cls = classOf[CustomMetricsDataSource].getName + spark.range(10).select($"id" as Symbol("i"), -$"id" as Symbol("j")) + .write.format(cls) + .option("path", dir.getCanonicalPath).mode("append").save() + + // Wait until the new execution is started and being tracked. + eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsCount() >= oldCount) + } + + // Wait for listener to finish computing the metrics for the execution. 
+ eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsList().nonEmpty && + statusStore.executionsList().last.metricValues != null) + } + + val execId = statusStore.executionsList().last.executionId + val metrics = statusStore.executionMetrics(execId) + val driverMetric = metrics.find(_._2 == "11111") + assert(driverMetric.isDefined) + } + } + test("SPARK-37578: Update output metrics from Datasource v2") { withTempDir { dir => val statusStore = spark.sharedState.statusStore @@ -1157,6 +1185,19 @@ class SimpleCustomDriverTaskMetric(value : Long) extends CustomTaskMetric { override def value(): Long = value } +class SimpleWriterCustomDriverMetric extends CustomMetric { + override def name(): String = "custom_writer_driver_metric" + override def description(): String = "Simple custom driver metrics - custom metric" + override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { + taskMetrics.sum.toString + } +} + +class SimpleWriterCustomDriverTaskMetric(value : Long) extends CustomTaskMetric { + override def name(): String = "custom_writer_driver_metric" + override def value(): Long = value +} + class BytesWrittenCustomMetric extends CustomMetric { override def name(): String = "bytesWritten" override def description(): String = "bytesWritten metric" @@ -1299,7 +1340,12 @@ class CustomMetricsDataSource extends SimpleWritableDataSource { override def supportedCustomMetrics(): Array[CustomMetric] = { Array(new SimpleCustomMetric, new Outer.InnerCustomMetric, - new BytesWrittenCustomMetric, new RecordsWrittenCustomMetric) + new BytesWrittenCustomMetric, new RecordsWrittenCustomMetric, + new SimpleWriterCustomDriverMetric) + } + + override def reportDriverMetrics(): Array[CustomTaskMetric] = { + Array(new SimpleWriterCustomDriverTaskMetric(11111)) } } } From 8706adfb6395951d822d41facc34bbbe22f7e69d Mon Sep 17 00:00:00 2001 From: zouxxyy Date: Mon, 12 Aug 2024 19:42:24 +0800 Subject: [PATCH 2/3] fix test --- .../spark/sql/connector/V1WriteFallbackSuite.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 278ec3b63c180..1f019d1e7dd91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -209,10 +209,13 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName) .getOrCreate() val df = session.createDataFrame(Seq((1, "p1"), (2, "p2"), (3, "p2"))) - df.write.partitionBy("_2").mode("append").format(v2Format).saveAsTable("test") + .toDF("id", "p_49210") + df.write.partitionBy("p_49210").mode("append").format(v2Format).saveAsTable("test") val statusStore = session.sharedState.statusStore val execId = statusStore.executionsList() - .find(x => x.physicalPlanDescription.contains("AppendData")).get.executionId + .find(x => x.metrics.exists(_.name.equals("number of written partitions") + && x.physicalPlanDescription.contains("p_49210"))) + .get.executionId val metrics = statusStore.executionMetrics(execId) assert(metrics.head._2 == "2") } finally { @@ -488,7 +491,7 @@ object OnlyOnceOptimizerRule extends Rule[LogicalPlan] { class WrittenPartitionDriverMetric extends CustomMetric { override def name(): String = 
"number_of_written_partitions" - override def description(): String = "Simple custom driver metrics: number of written partitions" + override def description(): String = "number of written partitions" override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = taskMetrics.sum.toString } From f68e20d093001ec3502c9e79327c1cfb0ddd0c51 Mon Sep 17 00:00:00 2001 From: zouxxyy Date: Tue, 13 Aug 2024 10:08:22 +0800 Subject: [PATCH 3/3] update --- .../datasources/v2/DataSourceV2ScanExecBase.scala | 4 ++-- .../execution/datasources/v2/V1FallbackWriters.scala | 4 ++-- .../datasources/v2/WriteToDataSourceV2Exec.scala | 4 ++-- .../apache/spark/sql/execution/metric/SQLMetrics.scala | 10 +++++----- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index b29643b5cfdb9..9be58f608b4d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils trait DataSourceV2ScanExecBase extends LeafExecNode { lazy val customMetrics: Map[String, SQLMetric] = - SQLMetrics.createV2CustomMetrics(sparkContext, scan) + SQLMetrics.createV2CustomMetrics(sparkContext, scan.supportedCustomMetrics()) override lazy val metrics: Map[String, SQLMetric] = { Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) ++ @@ -189,7 +189,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } protected def postDriverMetrics(): Unit = { - SQLMetrics.postV2DriverMetrics(sparkContext, scan, metrics) + SQLMetrics.postV2DriverMetrics(sparkContext, scan.reportDriverMetrics(), metrics) } override def doExecuteColumnar(): RDD[ColumnarBatch] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala index 9726620f389e7..a0fc7b7980d9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala @@ -64,14 +64,14 @@ sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write { def write: V1Write protected val customMetrics: Map[String, SQLMetric] = - SQLMetrics.createV2CustomMetrics(sparkContext, write) + SQLMetrics.createV2CustomMetrics(sparkContext, write.supportedCustomMetrics()) override lazy val metrics: Map[String, SQLMetric] = customMetrics override def run(): Seq[InternalRow] = { val writtenRows = writeWithV1(write.toInsertableRelation) refreshCache() - SQLMetrics.postV2DriverMetrics(sparkContext, write, metrics) + SQLMetrics.postV2DriverMetrics(sparkContext, write.reportDriverMetrics(), metrics) writtenRows } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 5e7ff1539fe59..a433a8e061ee1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala 
@@ -330,12 +330,12 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
   def write: Write
 
   override val customMetrics: Map[String, SQLMetric] =
-    SQLMetrics.createV2CustomMetrics(sparkContext, write)
+    SQLMetrics.createV2CustomMetrics(sparkContext, write.supportedCustomMetrics())
 
   override protected def run(): Seq[InternalRow] = {
     val writtenRows = writeWithV2(write.toBatch)
     refreshCache()
-    SQLMetrics.postV2DriverMetrics(sparkContext, write, metrics)
+    SQLMetrics.postV2DriverMetrics(sparkContext, write.reportDriverMetrics(), metrics)
     writtenRows
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 9b5bf63553422..b40d148404fa2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -26,7 +26,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
 
 import org.apache.spark.{SparkContext, SparkException}
 import org.apache.spark.scheduler.AccumulableInfo
-import org.apache.spark.sql.connector.metric.{CustomMetric, SupportCustomMetrics}
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
@@ -155,15 +155,15 @@ object SQLMetrics {
   }
 
   def createV2CustomMetrics(
-      sc: SparkContext, supportCustomMetrics: SupportCustomMetrics): Map[String, SQLMetric] = {
-    supportCustomMetrics.supportedCustomMetrics().map { customMetric =>
+      sc: SparkContext, customMetrics: Array[CustomMetric]): Map[String, SQLMetric] = {
+    customMetrics.map { customMetric =>
       customMetric.name() -> SQLMetrics.createV2CustomMetric(sc, customMetric)
     }.toMap
   }
 
-  def postV2DriverMetrics(sc: SparkContext, supportCustomMetrics: SupportCustomMetrics,
+  def postV2DriverMetrics(sc: SparkContext, driverMetrics: Array[CustomTaskMetric],
       metrics: Map[String, SQLMetric]): Unit = {
-    val driveSQLMetrics = supportCustomMetrics.reportDriverMetrics().map(customTaskMetric => {
+    val driveSQLMetrics = driverMetrics.map(customTaskMetric => {
       val metric = metrics(customTaskMetric.name())
       metric.set(customTaskMetric.value())
       metric