From 3d24f79f2cb46b3cea80aa7c9b37d5e9653c95f1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 21 Oct 2016 23:25:46 +0000 Subject: [PATCH 1/7] Remove shuffle codes in CollectLimitExec. --- .../sql/execution/BufferedRowIterator.java | 2 +- .../spark/sql/execution/SparkStrategies.scala | 4 +++- .../apache/spark/sql/execution/limit.scala | 19 +++++++++---------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 086547c793e3b..2c5a59f9f7401 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -39,7 +39,7 @@ public abstract class BufferedRowIterator { protected int partitionIndex = -1; public boolean hasNext() throws IOException { - if (currentRows.isEmpty()) { + if (!shouldStop()) { processNext(); } return !currentRows.isEmpty(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7cfae5ce283bf..6a1b2634a900f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -72,7 +72,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.TakeOrderedAndProjectExec( limit, order, projectList, planLater(child)) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.CollectLimitExec(limit, planLater(child)) :: Nil + execution.CollectLimitExec( + limit, + execution.LocalLimitExec(limit, planLater(child))) :: Nil case other => planLater(other) :: Nil } case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 86a8770715600..ebd5790bbd14f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} import org.apache.spark.util.Utils @@ -36,14 +36,14 @@ import org.apache.spark.util.Utils case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = SinglePartition - override def executeCollect(): Array[InternalRow] = child.executeTake(limit) - private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + override def executeCollect(): Array[InternalRow] = child match { + case e: Exchange => e.child.executeTake(limit) + case _ => child.executeTake(limit) + } + protected override def doExecute(): RDD[InternalRow] = { - val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) - val shuffled = new ShuffledRowRDD( - 
ShuffleExchange.prepareShuffleDependency( - locallyLimited, child.output, SinglePartition, serializer)) - shuffled.mapPartitionsInternal(_.take(limit)) + child.execute().mapPartitionsInternal(_.take(limit)) } } @@ -83,9 +83,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { s""" | if ($countTerm < $limit) { | $countTerm += 1; + | if ($countTerm == $limit) $stopEarly = true; | ${consume(ctx, input)} - | } else { - | $stopEarly = true; | } """.stripMargin } From 76a3eaf16ca68f1c0329cead866f8e2941285507 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 24 Oct 2016 03:38:18 +0000 Subject: [PATCH 2/7] Fix test. --- .../sql/execution/BufferedRowIterator.java | 12 ++++++++++- .../sql/execution/LocalTableScanExec.scala | 21 ++++++++++++++++++- .../aggregate/HashAggregateExec.scala | 9 ++++++++ .../apache/spark/sql/execution/limit.scala | 7 ++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ++++++++++---- 5 files changed, 56 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 2c5a59f9f7401..350582ca4d371 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -42,7 +42,12 @@ public boolean hasNext() throws IOException { if (!shouldStop()) { processNext(); } - return !currentRows.isEmpty(); + boolean hasNext = !currentRows.isEmpty(); + // If no more data available, releases resource if necessary. + if (!hasNext) { + releaseResource(); + } + return hasNext; } public InternalRow next() { @@ -91,4 +96,9 @@ protected void incPeakExecutionMemory(long size) { * After it's called, if currentRow is still null, it means no more rows left. */ protected abstract void processNext() throws IOException; + + /** + * Releases resources if necessary. No-op in default. 
+ */ + protected void releaseResource() {} } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index e366b9af35c62..51d8891f933ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.execution.metric.SQLMetrics @@ -28,11 +29,13 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class LocalTableScanExec( output: Seq[Attribute], - rows: Seq[InternalRow]) extends LeafExecNode { + rows: Seq[InternalRow]) extends LeafExecNode with CodegenSupport { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(rdd) + private val unsafeRows: Array[InternalRow] = { if (rows.isEmpty) { Array.empty @@ -47,6 +50,22 @@ case class LocalTableScanExec( private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows, numParallelism) + protected override def doProduce(ctx: CodegenContext): String = { + val numOutput = metricTerm(ctx, "numOutputRows") + val input = ctx.freshName("input") + // Right now, LocalTableScanExec is only used when there is one upstream. + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val row = ctx.freshName("row") + s""" + | while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutput.add(1); + | ${consume(ctx, null, row).trim} + | if (shouldStop()) return; + | } + """.stripMargin + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") rdd.map { r => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 06199ef3e8243..6d96ed84c99e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -661,6 +661,15 @@ case class HashAggregateExec( """.stripMargin } + ctx.addNewFunction("releaseResource", s""" + @Override + protected void releaseResource() { + $iterTerm.close(); + if ($sorterTerm == null) { + $hashMapTerm.free(); + } + } + """) val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index ebd5790bbd14f..e5ae6a23f62a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -37,9 +37,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = SinglePartition override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil - override def executeCollect(): Array[InternalRow] = child 
match { - case e: Exchange => e.child.executeTake(limit) - case _ => child.executeTake(limit) + override def executeCollect(): Array[InternalRow] = { + child.collect { + case l: LocalLimitExec => l + }.head.child.executeTake(limit) } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 60978efddd7f8..f390578b9c96a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.execution.LocalLimitExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -2684,11 +2685,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") { + val df = spark.range(1, 100, 1, numPartitions = 10).limit(1) + val localLimit = df.queryExecution.executedPlan.collect { + case l: LocalLimitExec => l + } + assert(localLimit.nonEmpty) val numRecordsRead = spark.sparkContext.longAccumulator - spark.range(1, 100, 1, numPartitions = 10).map { x => - numRecordsRead.add(1) - x - }.limit(1).queryExecution.toRdd.count() + localLimit.head.execute().mapPartitionsInternal { iter => + iter.map { x => + numRecordsRead.add(1) + x + } + }.count assert(numRecordsRead.value === 10) } From 58e8383c94db8837751fe5eb9d14c88ca614ff2b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 25 Oct 2016 06:13:30 +0000 Subject: [PATCH 3/7] Remove CollectLimitExec. 
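
A limit that is the last operator before results return to the driver does not need the local-limit / shuffle-to-one-partition / global-limit pipeline, because executeCollect can ask the plan under the limit for the first `limit` rows via SparkPlan.executeTake, which scans a growing number of partitions and stops once enough rows have been fetched. The RDD-level picture of the two paths, as an illustrative sketch only — the real operators work on InternalRow with an UnsafeRowSerializer rather than repartition, and `globalLimit` / `collectLimited` below are hypothetical helpers, not code from this patch:

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD

    // Shape of the executed plan when the limit is NOT terminal:
    // take per partition, shuffle the survivors into one partition,
    // then take again on that single partition.
    def globalLimit[T: ClassTag](rdd: RDD[T], limit: Int): RDD[T] = {
      val locallyLimited = rdd.mapPartitions(_.take(limit))   // LocalLimitExec
      val singlePartition = locallyLimited.repartition(1)     // ShuffleExchange to SinglePartition
      singlePartition.mapPartitions(_.take(limit))            // GlobalLimitExec
    }

    // Driver-side shortcut available when the limit is terminal:
    // no shuffle job at all, just an incremental fetch, analogous to
    // calling SparkPlan.executeTake(limit) on the plan below the limit.
    def collectLimited[T](rdd: RDD[T], limit: Int): Array[T] =
      rdd.take(limit)
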
--- .../spark/sql/execution/SparkStrategies.scala | 4 --- .../sql/execution/WholeStageCodegenExec.scala | 6 ++++ .../apache/spark/sql/execution/limit.scala | 36 ++++++++----------- .../spark/sql/execution/PlannerSuite.scala | 8 ++--- 4 files changed, 25 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6a1b2634a900f..053126a6b3b46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -71,10 +71,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { logical.Project(projectList, logical.Sort(order, true, child))) => execution.TakeOrderedAndProjectExec( limit, order, projectList, planLater(child)) :: Nil - case logical.Limit(IntegerLiteral(limit), child) => - execution.CollectLimitExec( - limit, - execution.LocalLimitExec(limit, planLater(child))) :: Nil case other => planLater(other) :: Nil } case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 62bf6f4a81eec..c82f34d712653 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -294,6 +294,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def executeCollect(): Array[InternalRow] = child match { + // This happens when the user is collecting results back to the driver, we could skip + // the shuffling and scan increasingly the RDD to get the limited items. + case g: GlobalLimitExec => g.executeCollect() + case _ => super.executeCollect() + } override lazy val metrics = Map( "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index e5ae6a23f62a1..72e054d32c3a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -27,27 +27,6 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} import org.apache.spark.util.Utils -/** - * Take the first `limit` elements and collect them to a single partition. - * - * This operator will be used when a logical `Limit` operation is the final operator in an - * logical plan, which happens when the user is collecting results back to the driver. 
- */ -case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = SinglePartition - override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil - override def executeCollect(): Array[InternalRow] = { - child.collect { - case l: LocalLimitExec => l - }.head.child.executeTake(limit) - } - - protected override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal(_.take(limit)) - } -} - /** * Helper trait which defines methods that are shared by both * [[LocalLimitExec]] and [[GlobalLimitExec]]. @@ -57,6 +36,18 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputPartitioning: Partitioning = child.outputPartitioning + override def executeCollect(): Array[InternalRow] = { + child match { + // Shuffling is inserted under whole stage codegen. + case InputAdapter(ShuffleExchange(_, WholeStageCodegenExec(l: LocalLimitExec), _)) => + l.executeCollect() + // Shuffling is inserted without whole stage codegen. + case ShuffleExchange(_, l: LocalLimitExec, _) => l.executeCollect() + // No shuffling happened. + case _ => child.executeTake(limit) + } + } + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } @@ -100,6 +91,9 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { /** * Take the first `limit` elements of the child's single output partition. + * If this is the final operator in physical plan, which happens when the user is collecting + * results back to the driver, we could skip the shuffling and scan increasingly the RDD to + * get the limited items. 
*/ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 375da224aaa7f..f53d0d65bd899 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -178,10 +178,10 @@ class PlannerSuite extends SharedSQLContext { assert(planned.output === testData.select('value, 'key).logicalPlan.output) } - test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") { + test("terminal limits that are not handled by TakeOrderedAndProject should use GlobalLimit") { val query = testData.select('value).limit(2) val planned = query.queryExecution.sparkPlan - assert(planned.isInstanceOf[CollectLimitExec]) + assert(planned.isInstanceOf[GlobalLimitExec]) assert(planned.output === testData.select('value).logicalPlan.output) } @@ -191,10 +191,10 @@ class PlannerSuite extends SharedSQLContext { assert(planned.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) } - test("CollectLimit can appear in the middle of a plan when caching is used") { + test("GlobalLimit can appear in the middle of a plan when caching is used") { val query = testData.select('key, 'value).limit(2).cache() val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] - assert(planned.child.isInstanceOf[CollectLimitExec]) + assert(planned.child.isInstanceOf[GlobalLimitExec]) } test("PartitioningCollection") { From 82ebff411540df45ae214ac16fddf9d2fba212fd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 25 Oct 2016 07:29:57 +0000 Subject: [PATCH 4/7] Polishing comment. --- .../apache/spark/sql/execution/WholeStageCodegenExec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index bc78234fee70c..5bf871296ed7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -299,8 +299,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def executeCollect(): Array[InternalRow] = child match { - // This happens when the user is collecting results back to the driver, we could skip - // the shuffling and scan increasingly the RDD to get the limited items. + // A physical Limit operator has optimized executeCollect which scans increasingly + // the RDD to get the limited items, without fully materializing the RDD. case g: GlobalLimitExec => g.executeCollect() case _ => super.executeCollect() } From 44c64e0e3ba49cd0d82817f866a29f9733bcb972 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 25 Oct 2016 08:24:00 +0000 Subject: [PATCH 5/7] fix test. 
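
The test fix below is needed because, with terminal limits now planned as GlobalLimitExec, whole-stage codegen wraps the limit before it is cached, so the InMemoryRelation's child is a WholeStageCodegenExec rather than the limit itself. An assertion that does not care how many wrappers codegen inserts could search the plan tree instead, in the same style the suite already uses for TakeOrderedAndProjectExec; this is only a sketch of an alternative, assuming InMemoryRelation.child is the cached SparkPlan as in the existing test:

    val query = testData.select('key, 'value).limit(2).cache()
    val cached = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation]
    // Look for the limit operator anywhere under the cached plan instead of
    // casting through each codegen wrapper explicitly.
    assert(cached.child.find(_.isInstanceOf[GlobalLimitExec]).isDefined)
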
--- .../scala/org/apache/spark/sql/execution/PlannerSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f53d0d65bd899..9be4ce062800b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -193,7 +193,8 @@ class PlannerSuite extends SharedSQLContext { test("GlobalLimit can appear in the middle of a plan when caching is used") { val query = testData.select('key, 'value).limit(2).cache() - val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] + val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation].child + .asInstanceOf[WholeStageCodegenExec] assert(planned.child.isInstanceOf[GlobalLimitExec]) } From 86b4e42af60f32b3a570af443d6264a3aa335204 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 26 Oct 2016 04:42:40 +0000 Subject: [PATCH 6/7] Minimize the necessary changes. --- .../sql/execution/BufferedRowIterator.java | 14 ++----------- .../sql/execution/LocalTableScanExec.scala | 21 +------------------ .../sql/execution/WholeStageCodegenExec.scala | 6 ------ .../aggregate/HashAggregateExec.scala | 9 -------- .../apache/spark/sql/execution/limit.scala | 18 +++++----------- .../execution/metric/SQLMetricsSuite.scala | 19 ++++++++++++----- 6 files changed, 22 insertions(+), 65 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 350582ca4d371..086547c793e3b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -39,15 +39,10 @@ public abstract class BufferedRowIterator { protected int partitionIndex = -1; public boolean hasNext() throws IOException { - if (!shouldStop()) { + if (currentRows.isEmpty()) { processNext(); } - boolean hasNext = !currentRows.isEmpty(); - // If no more data available, releases resource if necessary. - if (!hasNext) { - releaseResource(); - } - return hasNext; + return !currentRows.isEmpty(); } public InternalRow next() { @@ -96,9 +91,4 @@ protected void incPeakExecutionMemory(long size) { * After it's called, if currentRow is still null, it means no more rows left. */ protected abstract void processNext() throws IOException; - - /** - * Releases resources if necessary. No-op in default. 
- */ - protected void releaseResource() {} } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 51d8891f933ed..e366b9af35c62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.execution.metric.SQLMetrics @@ -29,13 +28,11 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class LocalTableScanExec( output: Seq[Attribute], - rows: Seq[InternalRow]) extends LeafExecNode with CodegenSupport { + rows: Seq[InternalRow]) extends LeafExecNode { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(rdd) - private val unsafeRows: Array[InternalRow] = { if (rows.isEmpty) { Array.empty @@ -50,22 +47,6 @@ case class LocalTableScanExec( private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows, numParallelism) - protected override def doProduce(ctx: CodegenContext): String = { - val numOutput = metricTerm(ctx, "numOutputRows") - val input = ctx.freshName("input") - // Right now, LocalTableScanExec is only used when there is one upstream. - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") - val row = ctx.freshName("row") - s""" - | while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutput.add(1); - | ${consume(ctx, null, row).trim} - | if (shouldStop()) return; - | } - """.stripMargin - } - protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") rdd.map { r => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 5bf871296ed7f..6303483f22fd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -298,12 +298,6 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co override def outputPartitioning: Partitioning = child.outputPartitioning override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def executeCollect(): Array[InternalRow] = child match { - // A physical Limit operator has optimized executeCollect which scans increasingly - // the RDD to get the limited items, without fully materializing the RDD. 
- case g: GlobalLimitExec => g.executeCollect() - case _ => super.executeCollect() - } override lazy val metrics = Map( "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 738d92d78319b..4529ed067e565 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -663,15 +663,6 @@ case class HashAggregateExec( """.stripMargin } - ctx.addNewFunction("releaseResource", s""" - @Override - protected void releaseResource() { - $iterTerm.close(); - if ($sorterTerm == null) { - $hashMapTerm.free(); - } - } - """) val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 8b54462f88c0f..e42a466f6246b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.util.Utils /** @@ -33,17 +33,8 @@ import org.apache.spark.util.Utils trait BaseLimitExec extends UnaryExecNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output - override def executeCollect(): Array[InternalRow] = { - child match { - // Shuffling is inserted under whole stage codegen. - case InputAdapter(ShuffleExchange(_, WholeStageCodegenExec(l: LocalLimitExec), _)) => - l.executeCollect() - // Shuffling is inserted without whole stage codegen. - case ShuffleExchange(_, l: LocalLimitExec, _) => l.executeCollect() - // No shuffling happened. 
- case _ => child.executeTake(limit) - } - } + override def executeCollect(): Array[InternalRow] = child.executeTake(limit) + override def executeTake(n: Int): Array[InternalRow] = child.executeTake(limit) protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) @@ -72,8 +63,9 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { s""" | if ($countTerm < $limit) { | $countTerm += 1; - | if ($countTerm == $limit) $stopEarly = true; | ${consume(ctx, input)} + | } else { + | $stopEarly = true; | } """.stripMargin } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 229d8814e0143..20322d50e4051 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -95,11 +95,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { assert(metrics1.contains("numOutputRows")) assert(metrics1("numOutputRows").value === 3) - val df2 = spark.createDataset(Seq(1, 2, 3)).limit(2) - df2.collect() - val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics - assert(metrics2.contains("numOutputRows")) - assert(metrics2("numOutputRows").value === 2) + // Due to the iteration processing of WholeStage Codegen, + // limit(n) operator will pull n + 1 items from previous operator. So we tune off + // WholeStage Codegen here to test the metrics. + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + // A physical GlobalLimitExec and a LocalLimitExec constitute a logical Limit operator. + // When we ask 2 items, each partition will be asked 2 items from the LocalLimitExec. + val df2 = spark.createDataset(Seq(1, 2, 3)).limit(1) + df2.collect() + // The number of partitions of the LocalTableScanExec. + val parallelism = math.min(sparkContext.defaultParallelism, 3) + val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics + assert(metrics2.contains("numOutputRows")) + assert(metrics2("numOutputRows").value === parallelism * 1) + } } test("Filter metrics") { From 89c0c6241720285fb19b8151a6882ceb8dc2563f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 27 Oct 2016 04:08:13 +0000 Subject: [PATCH 7/7] Revert CollectLimitExec back. 
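
This revision brings CollectLimitExec back as the terminal-limit operator, planned directly on top of a LocalLimitExec. Its executeCollect (in the diff below) reaches through whatever the planner put in between — a ShuffleExchange, a WholeStageCodegenExec, or both — to find that LocalLimitExec, then calls executeTake on its child, so collect() never runs the shuffle. The same lookup could be written without enumerating plan shapes, using the tree collection an earlier revision of this series used; shown here only as an illustrative alternative, not as what this patch does:

    // Hypothetical alternative body for CollectLimitExec.executeCollect:
    // locate the LocalLimitExec wherever it ended up in the child plan and
    // take `limit` rows from its input; fall back to executeTake on the
    // child if the planner ever produces a shape without a LocalLimitExec.
    override def executeCollect(): Array[InternalRow] = {
      child.collect { case l: LocalLimitExec => l }
        .headOption
        .map(_.child.executeTake(limit))
        .getOrElse(child.executeTake(limit))
    }
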
--- .../spark/sql/execution/SparkStrategies.scala | 4 +++ .../apache/spark/sql/execution/limit.scala | 35 ++++++++++++++++--- .../spark/sql/execution/PlannerSuite.scala | 11 +++--- .../execution/metric/SQLMetricsSuite.scala | 21 +++++------ 4 files changed, 47 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 053126a6b3b46..6a1b2634a900f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -71,6 +71,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { logical.Project(projectList, logical.Sort(order, true, child))) => execution.TakeOrderedAndProjectExec( limit, order, projectList, planLater(child)) :: Nil + case logical.Limit(IntegerLiteral(limit), child) => + execution.CollectLimitExec( + limit, + execution.LocalLimitExec(limit, planLater(child))) :: Nil case other => planLater(other) :: Nil } case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index e42a466f6246b..6c5e997792839 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -26,6 +26,36 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.util.Utils +/** + * Take the first `limit` elements and collect them to a single partition. + * + * This operator will be used when a logical `Limit` operation is the final operator in an + * logical plan, which happens when the user is collecting results back to the driver. + */ +case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = SinglePartition + override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + override def executeCollect(): Array[InternalRow] = child match { + // Shuffling injected. WholeStageCodegenExec enabled. + case ShuffleExchange(_, WholeStageCodegenExec(l: LocalLimitExec), _) => + l.child.executeTake(limit) + + // Shuffling injected. WholeStageCodegenExec disabled. + case ShuffleExchange(_, l: LocalLimitExec, _) => l.child.executeTake(limit) + + // No shuffled injected. WholeStageCodegenExec enabled. + case WholeStageCodegenExec(l: LocalLimitExec) => l.child.executeTake(limit) + + // No shuffling injected. WholeStageCodegenExec disabled. + case l: LocalLimitExec => l.child.executeTake(limit) + } + + protected override def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal(_.take(limit)) + } +} + /** * Helper trait which defines methods that are shared by both * [[LocalLimitExec]] and [[GlobalLimitExec]]. 
@@ -33,8 +63,6 @@ import org.apache.spark.util.Utils trait BaseLimitExec extends UnaryExecNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output - override def executeCollect(): Array[InternalRow] = child.executeTake(limit) - override def executeTake(n: Int): Array[InternalRow] = child.executeTake(limit) protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) @@ -83,9 +111,6 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { /** * Take the first `limit` elements of the child's single output partition. - * If this is the final operator in physical plan, which happens when the user is collecting - * results back to the driver, we could skip the shuffling and scan increasingly the RDD to - * get the limited items. */ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 9be4ce062800b..375da224aaa7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -178,10 +178,10 @@ class PlannerSuite extends SharedSQLContext { assert(planned.output === testData.select('value, 'key).logicalPlan.output) } - test("terminal limits that are not handled by TakeOrderedAndProject should use GlobalLimit") { + test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") { val query = testData.select('value).limit(2) val planned = query.queryExecution.sparkPlan - assert(planned.isInstanceOf[GlobalLimitExec]) + assert(planned.isInstanceOf[CollectLimitExec]) assert(planned.output === testData.select('value).logicalPlan.output) } @@ -191,11 +191,10 @@ class PlannerSuite extends SharedSQLContext { assert(planned.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) } - test("GlobalLimit can appear in the middle of a plan when caching is used") { + test("CollectLimit can appear in the middle of a plan when caching is used") { val query = testData.select('key, 'value).limit(2).cache() - val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation].child - .asInstanceOf[WholeStageCodegenExec] - assert(planned.child.isInstanceOf[GlobalLimitExec]) + val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] + assert(planned.child.isInstanceOf[CollectLimitExec]) } test("PartitioningCollection") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 20322d50e4051..57420f9af9946 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -95,19 +95,14 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { assert(metrics1.contains("numOutputRows")) assert(metrics1("numOutputRows").value === 3) - // Due to the iteration processing of WholeStage Codegen, - // limit(n) operator will pull n + 1 items from previous operator. So we tune off - // WholeStage Codegen here to test the metrics. - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { - // A physical GlobalLimitExec and a LocalLimitExec constitute a logical Limit operator. 
- // When we ask 2 items, each partition will be asked 2 items from the LocalLimitExec. - val df2 = spark.createDataset(Seq(1, 2, 3)).limit(1) - df2.collect() - // The number of partitions of the LocalTableScanExec. - val parallelism = math.min(sparkContext.defaultParallelism, 3) - val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics - assert(metrics2.contains("numOutputRows")) - assert(metrics2("numOutputRows").value === parallelism * 1) + Seq("true", "false").map { codeGen => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codeGen) { + val df2 = spark.createDataset(Seq(1, 2, 3)).coalesce(1).limit(2) + assert(df2.collect().length === 2) + val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics + assert(metrics2.contains("numOutputRows")) + assert(metrics2("numOutputRows").value === 2) + } } }
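
With the series applied end to end, a terminal limit plans as CollectLimitExec over LocalLimitExec, collect() is answered by executeTake on the plan beneath the local limit, and the metrics test exercises both codegen settings with a single-partition input so the leaf's numOutputRows is deterministic. A quick way to observe the result from a shell — illustrative only, where `spark` is a SparkSession built from this branch:

    // Terminal limit: the physical plan should roughly show CollectLimit 5
    // on top of LocalLimit 5 on top of Range (the limit and the range may
    // be fused into one WholeStageCodegen block).
    val df = spark.range(0, 100, 1, 10).limit(5)
    df.explain()
    df.collect()   // served through executeTake; no need to scan all 10 partitions

    // With one input partition, exactly `limit` rows are pulled lazily from
    // the leaf scan, so its numOutputRows metric equals the limit under
    // either codegen setting -- the case the updated SQLMetricsSuite asserts.
    spark.range(0, 3).coalesce(1).limit(2).collect().length   // == 2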