diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ab6031c436e9..9d9b020309d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -250,7 +250,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val codec = CompressionCodec.createCodec(SparkEnv.get.conf) val bos = new ByteArrayOutputStream() val out = new DataOutputStream(codec.compressedOutputStream(bos)) - while (iter.hasNext && (n < 0 || count < n)) { + // `iter.hasNext` may produce one row and buffer it, we should only call it when the limit is + // not hit. + while ((n < 0 || count < n) && iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] out.writeInt(row.getSizeInBytes) row.writeToStream(out, buffer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d45eb0c27a6b..085a44548848 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.ui.SQLAppStatusStore +import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -517,4 +517,57 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared test("writing data out metrics with dynamic partition: parquet") { testMetricsDynamicPartition("parquet", "parquet", "t1") } + + test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input when not necessary") { + def checkFilterAndRangeMetrics( + df: DataFrame, + filterNumOutputs: Int, + rangeNumOutputs: Int): Unit = { + var filter: FilterExec = null + var range: RangeExec = null + val collectFilterAndRange: SparkPlan => Unit = { + case f: FilterExec => + assert(filter == null, "the query should only have one Filter") + filter = f + case r: RangeExec => + assert(range == null, "the query should only have one Range") + range = r + case _ => + } + if (SQLConf.get.wholeStageEnabled) { + df.queryExecution.executedPlan.foreach { + case w: WholeStageCodegenExec => + w.child.foreach(collectFilterAndRange) + case _ => + } + } else { + df.queryExecution.executedPlan.foreach(collectFilterAndRange) + } + + assert(filter != null && range != null, "the query doesn't have Filter and Range") + assert(filter.metrics("numOutputRows").value == filterNumOutputs) + assert(range.metrics("numOutputRows").value == rangeNumOutputs) + } + + val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0) + val df2 = df.limit(2) + Seq(true, false).foreach { wholeStageEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageEnabled.toString) { + df.collect() + checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000) + + df.queryExecution.executedPlan.foreach(_.resetMetrics()) + // For each partition, we get 2 rows. Then the Filter should produce 2 rows per-partition, + // and Range should produce 1000 rows (one batch) per-partition. Totally Filter produces + // 4 rows, and Range produces 2000 rows. + df.queryExecution.toRdd.mapPartitions(_.take(2)).collect() + checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 2000) + + // Top-most limit will call `CollectLimitExec.executeCollect`, which will only run the first + // task, so totally the Filter produces 2 rows, and Range produces 1000 rows (one batch). + df2.collect() + checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 1000) + } + } + } }