diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index d4b50655c721..73b180468d36 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -255,40 +255,12 @@ trait ProgressReporter extends Logging {
     }
 
     if (onlyDataSourceV2Sources) {
-      // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data
-      // from a V2 source and has a direct reference to the V2 source that generated it. Each
-      // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However,
-      // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as
-      // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or
-      // even multiple times) points and considering it twice will lead to double counting. We
-      // can't dedup them using their hashcode either because two different instances of
-      // DataSourceV2ScanExec can have the same hashcode but account for separate sets of
-      // records read, and deduping them to consider only one of them would be undercounting the
-      // records read. Therefore the right way to do this is to consider the unique instances of
-      // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them.
-      // Hence we calculate in the following way.
-      //
-      // 1. Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap.
-      //
-      // 2. Extract the source and the number of rows read from the DataSourceV2ScanExec instanes.
-      //
-      // 3. Multiple DataSourceV2ScanExec instance may refer to the same source (can happen with
-      //    self-unions or self-joins). Add up the number of rows for each unique source.
-      val uniqueStreamingExecLeavesMap =
-        new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]()
-
-      lastExecution.executedPlan.collectLeaves().foreach {
+      val sourceToInputRowsTuples = lastExecution.executedPlan.collect {
         case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] =>
-          uniqueStreamingExecLeavesMap.put(s, s)
-        case _ =>
-      }
-
-      val sourceToInputRowsTuples =
-        uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf =>
-          val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
-          val source = execLeaf.readSupport.asInstanceOf[BaseStreamingSource]
+          val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
+          val source = s.readSupport.asInstanceOf[BaseStreamingSource]
           source -> numRows
-        }.toSeq
+      }
       logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t"))
       sumRows(sourceToInputRowsTuples)
     } else {
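
The comment deleted above describes an identity-based dedup: the same `DataSourceV2ScanExec` instance can be referenced from several points in the executed plan (e.g. a self-union), so counting every reference double-counts rows, while deduping by `equals`/`hashCode` would merge distinct-but-equal scans and undercount. Below is a minimal, standalone sketch of that technique, outside of Spark; `Source` and `ScanNode` are hypothetical stand-ins for `BaseStreamingSource` and `DataSourceV2ScanExec`, introduced here purely for illustration.

```scala
// Sketch of per-source row counting with identity-based dedup, as the removed
// comment describes. Not Spark code: Source/ScanNode are simplified stand-ins.
import java.util.IdentityHashMap
import scala.collection.JavaConverters._

object IdentityDedupSketch {
  final class Source(val name: String)
  final class ScanNode(val source: Source, val numOutputRows: Long)

  def sumRowsPerSource(planLeaves: Seq[ScanNode]): Map[Source, Long] = {
    // 1. Dedup by reference identity: the same instance reached via two plan
    //    paths is kept once; distinct-but-equal instances both survive.
    val unique = new IdentityHashMap[ScanNode, ScanNode]()
    planLeaves.foreach(n => unique.put(n, n))

    // 2./3. Extract (source, rows) from each unique node, then add rows up per
    //    source, since several distinct nodes may read from the same source.
    unique.values.asScala.foldLeft(Map.empty[Source, Long]) { (acc, node) =>
      acc.updated(node.source, acc.getOrElse(node.source, 0L) + node.numOutputRows)
    }
  }

  def main(args: Array[String]): Unit = {
    val src = new Source("kafka")
    val shared = new ScanNode(src, 10L) // one instance referenced twice (self-union)
    val other  = new ScanNode(src, 5L)  // a distinct instance over the same source
    // Counting every reference would give 25; identity dedup yields 15.
    println(sumRowsPerSource(Seq(shared, shared, other)).map {
      case (s, n) => s"${s.name} -> $n"
    }.mkString(", "))
  }
}
```

Running the sketch prints `kafka -> 15`: the shared instance is counted once, and the rows of the distinct scan over the same source are added on top. The diff above drops this bookkeeping in favor of a plain `collect` over the executed plan, which counts each matching node per occurrence in the tree.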