diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java
index dc5712e93f470..f62d194fa7f9f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java
@@ -24,6 +24,7 @@
 import org.apache.spark.sql.connector.catalog.Table;
 import org.apache.spark.sql.connector.catalog.TableCapability;
 import org.apache.spark.sql.connector.metric.CustomMetric;
+import org.apache.spark.sql.connector.metric.CustomTaskMetric;
 import org.apache.spark.sql.connector.write.streaming.StreamingWrite;
 
 /**
@@ -76,4 +77,14 @@ default StreamingWrite toStreaming() {
   default CustomMetric[] supportedCustomMetrics() {
     return new CustomMetric[]{};
   }
+
+  /**
+   * Returns an array of custom metrics whose values are collected only on the driver side.
+   * Note that these metrics must be included in the supported custom metrics reported by
+   * `supportedCustomMetrics`.
+   */
+  default CustomTaskMetric[] reportDriverMetrics() {
+    return new CustomTaskMetric[]{};
+  }
+
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 852e39931626d..497ef848ac78f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow,
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions._
-import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram, HistogramBin}
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
@@ -512,7 +512,11 @@ abstract class InMemoryBaseTable(
       }
 
       override def supportedCustomMetrics(): Array[CustomMetric] = {
-        Array(new InMemorySimpleCustomMetric)
+        Array(new InMemorySimpleCustomMetric, new InMemoryCustomDriverMetric)
+      }
+
+      override def reportDriverMetrics(): Array[CustomTaskMetric] = {
+        Array(new InMemoryCustomDriverTaskMetric(rows.size))
       }
     }
   }
@@ -754,3 +758,13 @@ class InMemorySimpleCustomMetric extends CustomMetric {
     s"in-memory rows: ${taskMetrics.sum}"
   }
 }
+
+class InMemoryCustomDriverMetric extends CustomSumMetric {
+  override def name(): String = "number_of_rows_from_driver"
+  override def description(): String = "number of rows from driver"
+}
+
+class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric {
+  override def name(): String = "number_of_rows_from_driver"
+  override def value(): Long = value
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 5885ec0afadcd..b238b0ce9760c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{LongAccumulator, Utils}
@@ -341,9 +341,22 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
 
   override protected def run(): Seq[InternalRow] = {
     val writtenRows = writeWithV2(write.toBatch)
+    postDriverMetrics()
     refreshCache()
     writtenRows
   }
+
+  protected def postDriverMetrics(): Unit = {
+    val driverSQLMetrics = write.reportDriverMetrics().map(customTaskMetric => {
+      val metric = metrics(customTaskMetric.name())
+      metric.set(customTaskMetric.value())
+      metric
+    })
+
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId,
+      driverSQLMetrics.toImmutableArraySeq)
+  }
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
index dee8d7ac3e794..7094404b3c1dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
@@ -42,7 +42,7 @@ class InMemoryTableMetricSuite
     spark.sessionState.conf.clear()
   }
 
-  private def testMetricOnDSv2(func: String => Unit, checker: Map[Long, String] => Unit): Unit = {
+  private def testMetricOnDSv2(func: String => Unit, checker: Map[String, String] => Unit): Unit = {
     withTable("testcat.table_name") {
       val statusStore = spark.sharedState.statusStore
       val oldCount = statusStore.executionsList().size
@@ -67,8 +67,14 @@ class InMemoryTableMetricSuite
         statusStore.executionsList().last.metricValues != null)
     }
 
-    val execId = statusStore.executionsList().last.executionId
-    val metrics = statusStore.executionMetrics(execId)
+    val exec = statusStore.executionsList().last
+    val execId = exec.executionId
+    val sqlMetrics = exec.metrics.map { metric =>
+      metric.accumulatorId -> metric.name
+    }.toMap
+    val metrics = statusStore.executionMetrics(execId).map { case (k, v) =>
+      sqlMetrics(k) -> v
+    }
     checker(metrics)
   }
 }
@@ -79,8 +85,8 @@ class InMemoryTableMetricSuite
       val v2Writer = df.writeTo(table)
      v2Writer.append()
     }, metrics => {
-      val customMetric = metrics.find(_._2 == "in-memory rows: 1")
-      assert(customMetric.isDefined)
+      assert(metrics.get("number of rows in buffer").contains("in-memory rows: 1"))
+      assert(metrics.get("number of rows from driver").contains("1"))
     })
   }
 
@@ -90,8 +96,8 @@ class InMemoryTableMetricSuite
       val v2Writer = df.writeTo(table)
       v2Writer.overwrite(lit(true))
     }, metrics => {
-      val customMetric = metrics.find(_._2 == "in-memory rows: 3")
-      assert(customMetric.isDefined)
+      assert(metrics.get("number of rows in buffer").contains("in-memory rows: 3"))
+      assert(metrics.get("number of rows from driver").contains("3"))
     })
   }
 }
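
For context, a minimal Scala sketch (not part of the patch) of how an external DSv2 connector might adopt the new reportDriverMetrics API; the names RowCountDriverMetric, RowCountDriverTaskMetric, MyWrite, and committedRows are hypothetical:

import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.write.{BatchWrite, Write}

// Hypothetical driver-side metric. Its name must also appear in supportedCustomMetrics(),
// as required by the reportDriverMetrics contract above.
class RowCountDriverMetric extends CustomSumMetric {
  override def name(): String = "rows_committed_on_driver"
  override def description(): String = "rows committed on driver"
}

class RowCountDriverTaskMetric(rows: Long) extends CustomTaskMetric {
  override def name(): String = "rows_committed_on_driver"
  override def value(): Long = rows
}

class MyWrite extends Write {
  // A value only known on the driver, e.g. accumulated while committing the batch write.
  @volatile private var committedRows: Long = 0L

  override def toBatch(): BatchWrite = throw new UnsupportedOperationException("sketch only")

  override def supportedCustomMetrics(): Array[CustomMetric] =
    Array(new RowCountDriverMetric)

  override def reportDriverMetrics(): Array[CustomTaskMetric] =
    Array(new RowCountDriverTaskMetric(committedRows))
}

After the batch write finishes, V2ExistingTableWriteExec calls reportDriverMetrics() and posts the returned values through SQLMetrics.postDriverMetricUpdates, as shown in the WriteToDataSourceV2Exec.scala hunk above, so the driver-side values appear alongside the task-aggregated custom metrics in the SQL UI.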