diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index 8e246dbbf5d7..e5f008804ee5 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -35,9 +35,11 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.sql.{ForeachWriter, SparkSession}
 import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.functions.{count, window}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.kafka010.KafkaSourceProvider._
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
@@ -598,18 +600,37 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
 
     val join = values.join(values, "key")
 
-    testStream(join)(
-      makeSureGetOffsetCalled,
-      AddKafkaData(Set(topic), 1, 2),
-      CheckAnswer((1, 1, 1), (2, 2, 2)),
-      AddKafkaData(Set(topic), 6, 3),
-      CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)),
-      AssertOnQuery { q =>
+    def checkQuery(check: AssertOnQuery): Unit = {
+      testStream(join)(
+        makeSureGetOffsetCalled,
+        AddKafkaData(Set(topic), 1, 2),
+        CheckAnswer((1, 1, 1), (2, 2, 2)),
+        AddKafkaData(Set(topic), 6, 3),
+        CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)),
+        check
+      )
+    }
+
+    withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") {
+      checkQuery(AssertOnQuery { q =>
         assert(q.availableOffsets.iterator.size == 1)
+        // The kafka source is scanned twice because of self-join
+        assert(q.recentProgress.map(_.numInputRows).sum == 8)
+        true
+      })
+    }
+
+    withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true") {
+      checkQuery(AssertOnQuery { q =>
+        assert(q.availableOffsets.iterator.size == 1)
+        assert(q.lastExecution.executedPlan.collect {
+          case r: ReusedExchangeExec => r
+        }.length == 1)
+        // The kafka source is scanned only once because of exchange reuse.
         assert(q.recentProgress.map(_.numInputRows).sum == 4)
         true
-      }
-    )
+      })
+    }
   }
 
   test("read Kafka transactional messages: read_committed") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 73b180468d36..392229bcb5f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -240,9 +240,6 @@ trait ProgressReporter extends Logging {
 
   /** Extract number of input sources for each streaming source in plan */
   private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = {
-    import java.util.IdentityHashMap
-    import scala.collection.JavaConverters._
-
     def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = {
       tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source
     }
@@ -255,6 +252,9 @@ trait ProgressReporter extends Logging {
     }
 
     if (onlyDataSourceV2Sources) {
+      // It's possible that multiple DataSourceV2ScanExec instances may refer to the same source
+      // (this can happen with self-unions or self-joins). This means the source is scanned
+      // multiple times in the query, so we should count the numRows for each scan.
       val sourceToInputRowsTuples = lastExecution.executedPlan.collect {
         case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] =>
           val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 1dd817545a96..c170641372d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid}
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
 import org.apache.spark.sql.functions._
@@ -500,29 +501,52 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
       AssertOnQuery { q =>
         val lastProgress = getLastProgressWithData(q)
         assert(lastProgress.nonEmpty)
-        assert(lastProgress.get.numInputRows == 6)
         assert(lastProgress.get.sources.length == 1)
-        assert(lastProgress.get.sources(0).numInputRows == 6)
+        // The source is scanned twice because of self-union
+        assert(lastProgress.get.numInputRows == 6)
         true
       }
     )
   }
 
   test("input row calculation with same V2 source used twice in self-join") {
-    val streamInput = MemoryStream[Int]
-    val df = streamInput.toDF()
-    testStream(df.join(df, "value"), useV2Sink = true)(
-      AddData(streamInput, 1, 2, 3),
-      CheckAnswer(1, 2, 3),
-      AssertOnQuery { q =>
+    def checkQuery(check: AssertOnQuery): Unit = {
+      val memoryStream = MemoryStream[Int]
+      // TODO: currently the streaming framework always adds a dummy Project above the streaming
+      // source relation, which breaks exchange reuse, as the optimizer will remove the Project
+      // from one side. Here we manually add a useful Project to trigger exchange reuse.
+      val streamDF = memoryStream.toDF().select('value + 0 as "v")
+      testStream(streamDF.join(streamDF, "v"), useV2Sink = true)(
+        AddData(memoryStream, 1, 2, 3),
+        CheckAnswer(1, 2, 3),
+        check
+      )
+    }
+
+    withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") {
+      checkQuery(AssertOnQuery { q =>
         val lastProgress = getLastProgressWithData(q)
         assert(lastProgress.nonEmpty)
+        assert(lastProgress.get.sources.length == 1)
+        // The source is scanned twice because of self-join
         assert(lastProgress.get.numInputRows == 6)
+        true
+      })
+    }
+
+    withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true") {
+      checkQuery(AssertOnQuery { q =>
+        val lastProgress = getLastProgressWithData(q)
+        assert(lastProgress.nonEmpty)
         assert(lastProgress.get.sources.length == 1)
-        assert(lastProgress.get.sources(0).numInputRows == 6)
+        assert(q.lastExecution.executedPlan.collect {
+          case r: ReusedExchangeExec => r
+        }.length == 1)
+        // The source is scanned only once because of exchange reuse
+        assert(lastProgress.get.numInputRows == 3)
         true
-      }
-    )
+      })
+    }
   }
 
   test("input row calculation with trigger having data for only one of two V2 sources") {
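For context, the row counts asserted in the tests above (8 vs 4 for Kafka, 6 vs 3 for MemoryStream) follow from the aggregation visible in extractSourceToNumInputRows: every DataSourceV2ScanExec in the executed plan contributes its numOutputRows metric, and sumRows adds them up per source, so a source scanned twice is counted twice unless exchange reuse drops the second scan. Below is a minimal, self-contained sketch of that aggregation, for illustration only and not part of the patch; the object name InputRowsSketch, the String stand-in for BaseStreamingSource, and the sample row counts are made up.

import scala.collection.immutable.Map

object InputRowsSketch {
  // Stand-in for BaseStreamingSource; in Spark the key is the actual source instance.
  type Source = String

  // Same shape as the sumRows helper kept in ProgressReporter: group the
  // (source, numRows) tuples collected from each scan and sum rows per source.
  def sumRows(tuples: Seq[(Source, Long)]): Map[Source, Long] =
    tuples.groupBy(_._1).mapValues(_.map(_._2).sum).toMap

  def main(args: Array[String]): Unit = {
    // Self-join with exchange reuse disabled: the same source sits behind two scans,
    // so its rows are counted twice.
    println(sumRows(Seq("kafkaSource" -> 4L, "kafkaSource" -> 4L))) // Map(kafkaSource -> 8)
    // With exchange reuse enabled only one scan survives in the executed plan.
    println(sumRows(Seq("kafkaSource" -> 4L)))                      // Map(kafkaSource -> 4)
  }
}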