
Revert CollectLimitExec back.
viirya committed Oct 27, 2016
1 parent 86b4e42 commit 89c0c62
Showing 4 changed files with 47 additions and 24 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

@@ -71,6 +71,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             logical.Project(projectList, logical.Sort(order, true, child))) =>
           execution.TakeOrderedAndProjectExec(
             limit, order, projectList, planLater(child)) :: Nil
+        case logical.Limit(IntegerLiteral(limit), child) =>
+          execution.CollectLimitExec(
+            limit,
+            execution.LocalLimitExec(limit, planLater(child))) :: Nil
         case other => planLater(other) :: Nil
       }
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
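With this strategy restored, a terminal logical Limit plans to a CollectLimitExec wrapping a per-partition LocalLimitExec. A minimal sketch of the resulting plan shape (the demo object, master setting, and printed tree are illustrative assumptions, not part of this commit):

import org.apache.spark.sql.SparkSession

object CollectLimitPlanDemo {
  def main(args: Array[String]): Unit = {
    // Local session purely for illustration.
    val spark = SparkSession.builder().master("local[2]").appName("collect-limit-plan").getOrCreate()
    val df = spark.range(100).limit(5)
    // sparkPlan is the physical plan before exchanges and codegen are inserted;
    // under the strategy above it should print roughly:
    //   CollectLimit 5
    //   +- LocalLimit 5
    //      +- Range (0, 100, splits=2)
    println(df.queryExecution.sparkPlan.treeString)
    spark.stop()
  }
}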
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala (30 additions, 5 deletions)

@@ -26,15 +26,43 @@
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.util.Utils

+/**
+ * Take the first `limit` elements and collect them to a single partition.
+ *
+ * This operator will be used when a logical `Limit` operation is the final operator in a
+ * logical plan, which happens when the user is collecting results back to the driver.
+ */
+case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
+  override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = SinglePartition
+  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+  override def executeCollect(): Array[InternalRow] = child match {
+    // Shuffling injected. WholeStageCodegenExec enabled.
+    case ShuffleExchange(_, WholeStageCodegenExec(l: LocalLimitExec), _) =>
+      l.child.executeTake(limit)
+
+    // Shuffling injected. WholeStageCodegenExec disabled.
+    case ShuffleExchange(_, l: LocalLimitExec, _) => l.child.executeTake(limit)
+
+    // No shuffling injected. WholeStageCodegenExec enabled.
+    case WholeStageCodegenExec(l: LocalLimitExec) => l.child.executeTake(limit)
+
+    // No shuffling injected. WholeStageCodegenExec disabled.
+    case l: LocalLimitExec => l.child.executeTake(limit)
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsInternal(_.take(limit))
+  }
+}

 /**
  * Helper trait which defines methods that are shared by both
  * [[LocalLimitExec]] and [[GlobalLimitExec]].
  */
 trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
   val limit: Int
   override def output: Seq[Attribute] = child.output
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
   override def executeTake(n: Int): Array[InternalRow] = child.executeTake(limit)

   protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
     iter.take(limit)
@@ -83,9 +111,6 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {

 /**
  * Take the first `limit` elements of the child's single output partition.
- * If this is the final operator in physical plan, which happens when the user is collecting
- * results back to the driver, we could skip the shuffling and scan increasingly the RDD to
- * get the limited items.
  */
 case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {

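The executeCollect override above is the point of the revert: it unwraps the injected ShuffleExchange and WholeStageCodegenExec layers and calls executeTake directly on the LocalLimitExec's child, so collecting a terminal limit fetches roughly `limit` rows instead of first shuffling every row into a single partition. A usage sketch (demo object and row counts are assumptions, not from this commit):

import org.apache.spark.sql.SparkSession

object CollectLimitFastPathDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[4]").appName("collect-limit-fast-path").getOrCreate()
    // collect() routes through CollectLimitExec.executeCollect, which takes the
    // first 3 rows via executeTake rather than scanning all one million rows.
    val rows = spark.range(0, 1000000).limit(3).collect()
    assert(rows.length == 3)
    spark.stop()
  }
}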
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -178,10 +178,10 @@ class PlannerSuite extends SharedSQLContext {
     assert(planned.output === testData.select('value, 'key).logicalPlan.output)
   }

-  test("terminal limits that are not handled by TakeOrderedAndProject should use GlobalLimit") {
+  test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") {
     val query = testData.select('value).limit(2)
     val planned = query.queryExecution.sparkPlan
-    assert(planned.isInstanceOf[GlobalLimitExec])
+    assert(planned.isInstanceOf[CollectLimitExec])
     assert(planned.output === testData.select('value).logicalPlan.output)
   }

@@ -191,11 +191,10 @@
     assert(planned.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined)
   }

-  test("GlobalLimit can appear in the middle of a plan when caching is used") {
+  test("CollectLimit can appear in the middle of a plan when caching is used") {
     val query = testData.select('key, 'value).limit(2).cache()
-    val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation].child
-      .asInstanceOf[WholeStageCodegenExec]
-    assert(planned.child.isInstanceOf[GlobalLimitExec])
+    val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation]
+    assert(planned.child.isInstanceOf[CollectLimitExec])
   }

test("PartitioningCollection") {
Expand Down
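The rewritten caching test relies on cache() materializing the limited result: the optimized plan roots at an InMemoryRelation whose child plan is the CollectLimit, so the limit sits in the middle of the overall plan rather than at the root. A sketch of inspecting that shape (demo object and sample data are assumptions):

import org.apache.spark.sql.SparkSession

object CachedLimitDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("cached-limit").getOrCreate()
    import spark.implicits._
    val cached = Seq((1, "a"), (2, "b"), (3, "c")).toDF("key", "value").limit(2).cache()
    // The optimized plan should be an InMemoryRelation whose child plan is
    // CollectLimit 2 over LocalLimit 2 over the local scan.
    println(cached.queryExecution.optimizedPlan.treeString)
    spark.stop()
  }
}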
sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -95,19 +95,14 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     assert(metrics1.contains("numOutputRows"))
     assert(metrics1("numOutputRows").value === 3)

-    // Due to the iteration processing of WholeStage Codegen,
-    // limit(n) operator will pull n + 1 items from previous operator. So we tune off
-    // WholeStage Codegen here to test the metrics.
-    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
-      // A physical GlobalLimitExec and a LocalLimitExec constitute a logical Limit operator.
-      // When we ask 2 items, each partition will be asked 2 items from the LocalLimitExec.
-      val df2 = spark.createDataset(Seq(1, 2, 3)).limit(1)
-      df2.collect()
-      // The number of partitions of the LocalTableScanExec.
-      val parallelism = math.min(sparkContext.defaultParallelism, 3)
-      val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics
-      assert(metrics2.contains("numOutputRows"))
-      assert(metrics2("numOutputRows").value === parallelism * 1)
+    Seq("true", "false").map { codeGen =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codeGen) {
+        val df2 = spark.createDataset(Seq(1, 2, 3)).coalesce(1).limit(2)
+        assert(df2.collect().length === 2)
+        val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics
+        assert(metrics2.contains("numOutputRows"))
+        assert(metrics2("numOutputRows").value === 2)
+      }
     }
   }

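The rewritten metrics test coalesces to one partition because, with N partitions, each LocalLimitExec can emit up to `limit` rows, so the leaf scan's numOutputRows would depend on parallelism; with a single partition it is exactly 2 under both codegen settings. A standalone sketch of the same check (demo object and session settings are assumptions):

import org.apache.spark.sql.SparkSession

object LimitMetricsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[4]").appName("limit-metrics").getOrCreate()
    import spark.implicits._
    // coalesce(1) pins the scan to one partition, so LocalLimit pulls exactly 2 rows.
    val df = Seq(1, 2, 3).toDS().coalesce(1).limit(2)
    assert(df.collect().length == 2)
    val leaf = df.queryExecution.executedPlan.collectLeaves().head
    println(leaf.metrics("numOutputRows").value) // expected: 2
    spark.stop()
  }
}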
