Commit c02324c
Use requiredChildDistribution in Limit
1 parent 089f9f5 commit c02324c

2 files changed: +19 −10 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 5 deletions

@@ -346,12 +346,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
     case logical.LocalRelation(output, data) =>
       LocalTableScan(output, data) :: Nil
-    case logical.Limit(IntegerLiteral(limit), child) => {
-      val perPartitionLimit = execution.PartitionLocalLimit(limit, planLater(child))
-      val globalLimit = execution.PartitionLocalLimit(
-        limit, execution.Exchange(SinglePartition, Nil, perPartitionLimit))
+    case logical.Limit(IntegerLiteral(limit), child) =>
+      val perPartitionLimit = execution.Limit(global = false, limit, planLater(child))
+      val globalLimit = execution.Limit(global = true, limit, perPartitionLimit)
       globalLimit :: Nil
-    }
     case Unions(unionChildren) =>
       execution.Union(unionChildren.map(planLater)) :: Nil
     case logical.Except(left, right) =>
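The old strategy built the single-partition Exchange by hand; the new one simply stacks a global Limit on top of a per-partition Limit and lets the operator's requiredChildDistribution pull the exchange in during preparation. A minimal, self-contained sketch of that idea follows; Plan, Scan, Exchange, and ensureDistribution are illustrative stand-ins, not Spark's actual classes or its EnsureRequirements rule.

    // Toy model of how a distribution requirement drives exchange insertion.
    // All names here are hypothetical, chosen for the sketch only.
    sealed trait Distribution
    case object UnspecifiedDistribution extends Distribution
    case object AllTuples extends Distribution

    sealed trait Plan { def children: Seq[Plan] }
    case class Scan(name: String) extends Plan { def children: Seq[Plan] = Nil }
    case class Exchange(child: Plan) extends Plan { def children: Seq[Plan] = Seq(child) }
    case class Limit(global: Boolean, limit: Int, child: Plan) extends Plan {
      def children: Seq[Plan] = Seq(child)
      // Mirrors the commit: a global limit demands all tuples in one place.
      def requiredChildDistribution: Seq[Distribution] =
        if (global) AllTuples :: Nil else UnspecifiedDistribution :: Nil
    }

    // Walk the tree; wrap a child in Exchange when its parent demands AllTuples.
    def ensureDistribution(plan: Plan): Plan = plan match {
      case l @ Limit(g, n, child) =>
        val prepared = ensureDistribution(child)
        val satisfied =
          if (l.requiredChildDistribution.head == AllTuples) Exchange(prepared)
          else prepared
        Limit(g, n, satisfied)
      case other => other
    }

    // The strategy above emits Limit(global = true, ...) directly over
    // Limit(global = false, ...); the exchange appears between them later.
    val planned = Limit(global = true, 10, Limit(global = false, 10, Scan("t")))
    println(ensureDistribution(planned))
    // Limit(true,10,Exchange(Limit(false,10,Scan(t))))

This keeps the strategy declarative: it states what the plan shape is, and the shared preparation machinery decides where data movement is needed.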

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 16 additions & 5 deletions

@@ -25,11 +25,9 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
-import org.apache.spark.util.{CompletionIterator, MutablePair}
-import org.apache.spark.{HashPartitioner, SparkEnv}
+import org.apache.spark.util.CompletionIterator

 /**
  * :: DeveloperApi ::
@@ -109,11 +107,24 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan {

 /**
  * :: DeveloperApi ::
- * Take the first `limit` elements from each partition.
+ * Take the first `limit` elements.
+ *
+ * @param global if true, then this operator will take the first `limit` elements of the entire
+ *               input. If false, it will take the first `limit` elements of each partition.
+ * @param limit the number of elements to take.
+ * @param child the input data source.
  */
 @DeveloperApi
-case class PartitionLocalLimit(limit: Int, child: SparkPlan)
+case class Limit(global: Boolean, limit: Int, child: SparkPlan)
   extends UnaryNode {
+  override def requiredChildDistribution: List[Distribution] = {
+    if (global) {
+      AllTuples :: Nil
+    } else {
+      UnspecifiedDistribution :: Nil
+    }
+  }
+
   override def output: Seq[Attribute] = child.output

   override def executeCollect(): Array[Row] = child.executeTake(limit)
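The global/local split is easy to see on a toy partitioned dataset. A minimal sketch, with plain Scala collections standing in for RDD partitions (not Spark code):

    // Partitions modeled as plain Seqs; a local limit trims each partition,
    // a global limit trims the single coalesced partition.
    val partitions: Seq[Seq[Int]] = Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6, 7, 8, 9))

    // Limit(global = false, 2, _): first 2 elements of every partition.
    val localLimited = partitions.map(_.take(2))       // Seq(Seq(1,2), Seq(4,5), Seq(6,7))

    // The AllTuples requirement forces an exchange into one partition...
    val singlePartition = Seq(localLimited.flatten)    // Seq(Seq(1,2,4,5,6,7))

    // ...so Limit(global = true, 2, _) yields exactly 2 rows overall.
    val globalLimited = singlePartition.map(_.take(2)) // Seq(Seq(1,2))

Running the local limit before the exchange bounds how much data moves: at most limit rows per input partition cross the shuffle, rather than every row.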
