diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 7850dfa39d16..217a1d5750d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition)
@@ -66,7 +66,12 @@ class DataSourceRDD(
         new PartitionIterator[InternalRow](rowReader, customMetrics))
       (iter, rowReader)
     }
-    context.addTaskCompletionListener[Unit](_ => reader.close())
+    context.addTaskCompletionListener[Unit] { _ =>
+      // In case of early stopping before consuming the entire iterator,
+      // we need to do one more metric update at the end of the task.
+      CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
+      reader.close()
+    }
     // TODO: SPARK-25083 remove the type erasure hack in data source scan
     new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]])
   }
@@ -81,6 +86,8 @@ private class PartitionIterator[T](
     customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
   private[this] var valuePrepared = false
 
+  private var numRow = 0L
+
   override def hasNext: Boolean = {
     if (!valuePrepared) {
       valuePrepared = reader.next()
@@ -92,12 +99,10 @@ private class PartitionIterator[T](
     if (!hasNext) {
       throw QueryExecutionErrors.endOfStreamError()
     }
-    reader.currentMetricsValues.foreach { metric =>
-      assert(customMetrics.contains(metric.name()),
-        s"Custom metrics ${customMetrics.keys.mkString(", ")} do not contain the metric " +
-          s"${metric.name()}")
-      customMetrics(metric.name()).set(metric.value())
+    if (numRow % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
+      CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
     }
+    numRow += 1
     valuePrepared = false
     reader.get()
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
index f2449a1ec58f..3e6cad2676e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.sql.execution.metric
 
-import org.apache.spark.sql.connector.metric.CustomMetric
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 
 object CustomMetrics {
   private[spark] val V2_CUSTOM = "v2Custom"
 
+  private[spark] val NUM_ROWS_PER_UPDATE = 100
+
   /**
    * Given a class name, builds and returns a metric type for a V2 custom metric class
    * `CustomMetric`.
@@ -41,4 +43,15 @@ object CustomMetrics {
       None
     }
   }
+
+  /**
+   * Updates given custom metrics.
+   */
+  def updateMetrics(
+      currentMetricsValues: Seq[CustomTaskMetric],
+      customMetrics: Map[String, SQLMetric]): Unit = {
+    currentMetricsValues.foreach { metric =>
+      customMetrics(metric.name()).set(metric.value())
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
index 4e32cefbe31a..6d27961fa0bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -22,7 +22,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.connector.read.streaming.ContinuousPartitionReaderFactory
-import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.NextIterator
 
@@ -92,10 +92,13 @@ class ContinuousDataSourceRDD(
     val partitionReader = readerForPartition.getPartitionReader()
 
     new NextIterator[InternalRow] {
+      private var numRow = 0L
+
       override def getNext(): InternalRow = {
-        partitionReader.currentMetricsValues.foreach { metric =>
-          customMetrics(metric.name()).set(metric.value())
+        if (numRow % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
+          CustomMetrics.updateMetrics(partitionReader.currentMetricsValues, customMetrics)
         }
+        numRow += 1
         readerForPartition.next() match {
           case null =>
             finished = true
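The patch above reduces the per-row cost of keeping DSv2 custom metrics current: PartitionIterator and the continuous-streaming iterator copy the reader's CustomTaskMetric values into their SQLMetric counterparts only every CustomMetrics.NUM_ROWS_PER_UPDATE (100) rows, and the task-completion listener performs one final CustomMetrics.updateMetrics call so the latest values are not lost when the iterator stops early (for example under a LIMIT).

The values consumed by updateMetrics come from a connector's metric implementations. A minimal sketch of such a pair follows; the bytesRead metric and both class names are hypothetical, used only to illustrate the CustomMetric / CustomTaskMetric interfaces that the updated code path reads via PartitionReader.currentMetricsValues.

import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}

// Hypothetical example metric, not part of this patch.
// Driver-side definition: declares the metric name and how per-task values
// are aggregated for display in the SQL UI.
class BytesReadMetric extends CustomMetric {
  override def name(): String = "bytesRead"
  override def description(): String = "number of bytes read by the scan"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString
}

// Executor-side snapshot returned from a reader's currentMetricsValues();
// with this patch its value() is copied into the matching SQLMetric every
// NUM_ROWS_PER_UPDATE rows and once more by the task-completion listener.
class BytesReadTaskMetric(bytes: Long) extends CustomTaskMetric {
  override def name(): String = "bytesRead"
  override def value(): Long = bytes
}

Note that the per-row assertion on unknown metric names is gone: updateMetrics indexes customMetrics directly, so it assumes every name reported by the reader already has a registered SQLMetric for the scan.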