Changes from all commits (20 commits)
@@ -25,6 +25,17 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._

/**
* When planning take() or collect() operations, this special node is inserted at the top of
* the logical plan before invoking the query planner.
*
* Rules can pattern-match on this node in order to apply transformations that only take effect
* at the top of the logical query plan.
*/
case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}
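
A minimal sketch of how a planner rule can key off this node (this fragment paraphrases the strategy cases added to SparkStrategies later in this diff; it is not additional code in the PR): a Limit that sits directly beneath ReturnAnswer is the last operator before rows are returned to the driver, so it can be planned as a driver-side collect-with-limit, while a bare ReturnAnswer is transparent.

  // Fragment only -- mirrors the ReturnAnswer cases in SparkStrategies below.
  case ReturnAnswer(Limit(IntegerLiteral(limit), child)) =>
    // Terminal limit: gather at most `limit` rows on the driver.
    CollectLimit(limit, planLater(child)) :: Nil
  case ReturnAnswer(child) =>
    // Otherwise the marker node adds nothing: just plan its child.
    planLater(child) :: Nil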

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)

130 changes: 71 additions & 59 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -57,6 +57,69 @@ case class Exchange(

override def output: Seq[Attribute] = child.output

private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)

override protected def doPrepare(): Unit = {
// If an ExchangeCoordinator is needed, we register this Exchange operator
// to the coordinator when we do prepare. It is important to make sure
// we register this operator right before the execution instead of registering it
// in the constructor because it is possible that we create new instances of
// Exchange operators when we transform the physical plan
// (then the ExchangeCoordinator will hold references of unneeded Exchanges).
// So, we should only call registerExchange just before we start to execute
// the plan.
coordinator match {
case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this)
case None =>
}
}

/**
* Returns a [[ShuffleDependency]] that will partition rows of its child based on
* the partitioning scheme defined in `newPartitioning`. Those partitions of
* the returned ShuffleDependency will be the input of shuffle.
*/
private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
Exchange.prepareShuffleDependency(child.execute(), child.output, newPartitioning, serializer)
}

/**
* Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset.
* This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional
* partition start indices array. If this optional array is defined, the returned
* [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array.
*/
private[sql] def preparePostShuffleRDD(
shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
// If an array of partition start indices is provided, we need to use this array
// to create the ShuffledRowRDD. Also, we need to update newPartitioning to
// reflect the number of post-shuffle partitions.
specifiedPartitionStartIndices.foreach { indices =>
assert(newPartitioning.isInstanceOf[HashPartitioning])
newPartitioning = UnknownPartitioning(indices.length)
}
new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
}
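
For intuition about the optional start-indices array: each entry is the first pre-shuffle (map-side) partition of one post-shuffle partition, which is how an ExchangeCoordinator coalesces small post-shuffle partitions. A small, self-contained sketch of that mapping (the numbers are illustrative, not taken from this diff):

  // Illustrative only: 5 pre-shuffle partitions coalesced into 2 post-shuffle partitions.
  val numPreShufflePartitions = 5
  val startIndices = Array(0, 3)
  // Post-shuffle partition i reads pre-shuffle partitions [startIndices(i), next start).
  val ranges = startIndices.zip(startIndices.drop(1) :+ numPreShufflePartitions)
  // ranges == Array((0, 3), (3, 5))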

protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
coordinator match {
case Some(exchangeCoordinator) =>
val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
shuffleRDD
case None =>
val shuffleDependency = prepareShuffleDependency()
preparePostShuffleRDD(shuffleDependency)
}
}
}

object Exchange {
def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
}

/**
* Determines whether records must be defensively copied before being sent to the shuffle.
* Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -82,7 +145,7 @@ case class Exchange(
// passed instead of directly passing the number of partitions in order to guard against
// corner-cases where a partitioner constructed with `numPartitions` partitions may output
// fewer partitions (like RangePartitioner, for example).
val conf = child.sqlContext.sparkContext.conf
val conf = SparkEnv.get.conf
val shuffleManager = SparkEnv.get.shuffleManager
val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager]
val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
@@ -117,30 +180,16 @@ case class Exchange(
}
}

private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)

override protected def doPrepare(): Unit = {
// If an ExchangeCoordinator is needed, we register this Exchange operator
// to the coordinator when we do prepare. It is important to make sure
// we register this operator right before the execution instead of registering it
// in the constructor because it is possible that we create new instances of
// Exchange operators when we transform the physical plan
// (then the ExchangeCoordinator will hold references of unneeded Exchanges).
// So, we should only call registerExchange just before we start to execute
// the plan.
coordinator match {
case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this)
case None =>
}
}

/**
* Returns a [[ShuffleDependency]] that will partition rows of its child based on
* the partitioning scheme defined in `newPartitioning`. Those partitions of
* the returned ShuffleDependency will be the input of shuffle.
*/
private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
val rdd = child.execute()
private[sql] def prepareShuffleDependency(
rdd: RDD[InternalRow],
outputAttributes: Seq[Attribute],
newPartitioning: Partitioning,
serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, n) =>
@@ -160,7 +209,7 @@ case class Exchange(
// We need to use an interpreted ordering here because generated orderings cannot be
// serialized and this ordering needs to be created on the driver in order to be passed into
// Spark core code.
implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output)
implicit val ordering = new InterpretedOrdering(sortingExpressions, outputAttributes)
new RangePartitioner(numPartitions, rddForSampling, ascending = true)
case SinglePartition =>
new Partitioner {
@@ -180,7 +229,7 @@ case class Exchange(
position
}
case h: HashPartitioning =>
val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output)
val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
row => projection(row).getInt(0)
case RangePartitioning(_, _) | SinglePartition => identity
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
@@ -211,43 +260,6 @@ case class Exchange(

dependency
}
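
Because prepareShuffleDependency is now a static helper on object Exchange that takes the input RDD, output attributes, target partitioning, and serializer explicitly, other physical operators can reuse the shuffle machinery without constructing an Exchange node. A hedged sketch of such a call site (not code from this diff; `child` stands for some SparkPlan):

  // Sketch only: shuffle all of a child plan's rows into a single partition.
  val shuffleDependency = Exchange.prepareShuffleDependency(
    child.execute(),                            // RDD[InternalRow] produced by the child
    child.output,                               // the child's output attributes
    SinglePartition,                            // target partitioning
    new UnsafeRowSerializer(child.output.size)) // same serializer Exchange itself uses
  val singlePartitionRDD = new ShuffledRowRDD(shuffleDependency, None)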

/**
* Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset.
* This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional
* partition start indices array. If this optional array is defined, the returned
* [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array.
*/
private[sql] def preparePostShuffleRDD(
shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
// If an array of partition start indices is provided, we need to use this array
// to create the ShuffledRowRDD. Also, we need to update newPartitioning to
// reflect the number of post-shuffle partitions.
specifiedPartitionStartIndices.foreach { indices =>
assert(newPartitioning.isInstanceOf[HashPartitioning])
newPartitioning = UnknownPartitioning(indices.length)
}
new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
}

protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
coordinator match {
case Some(exchangeCoordinator) =>
val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
shuffleRDD
case None =>
val shuffleDependency = prepareShuffleDependency()
preparePostShuffleRDD(shuffleDependency)
}
}
}

object Exchange {
def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
}
}

/**
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}

/**
* The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -44,7 +44,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {

lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sqlContext)
sqlContext.planner.plan(optimizedPlan).next()
sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next()

Contributor: cc @marmbrus here to make sure it is safe

Contributor: Yeah, this looks good to me.

}

// executedPlan should not be used to initialize any SparkPlan. It should be
@@ -337,8 +337,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) =>
execution.CollectLimit(limit, planLater(child)) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child)) :: Nil
val perPartitionLimit = execution.LocalLimit(limit, planLater(child))

Contributor: Should we have a test for this?

   select * from (select * from A limit 10 ) t1 UNION ALL (select * from B limit 10) t2

Contributor (author): Today this will be planned as a scan -> local limit -> global limit for each branch of the union, which I believe is correct: this query should return at most 20 results.

What do you think this test should be asserting?

Contributor: I'm thinking that we did not have this kind of test case, right?

Member: This is not allowed. See my PR: #10689

If you think this is valid, should we reopen it? Thanks!

Contributor (author): @davies, to clarify, are you proposing that we need to add a test to ensure that global limits below union are not pulled up?

Contributor (author): @gatorsmile: I think we're in agreement. To recap:

Before this patch, select * from (select * from A limit 10 ) t1 UNION ALL (select * from B limit 10) t2 should get planned as:

          ┌─────────┐         
          │  Union  │         
          └─────────┘         
               ▲              
      ┌────────┴────────┐     
      │                 │     
┌──────────┐      ┌──────────┐
│ Limit 10 │      │ Limit 10 │
└──────────┘      └──────────┘
      ▲                 ▲     
      │                 │     
┌──────────┐      ┌──────────┐
│  Scan A  │      │  Scan B  │
└──────────┘      └──────────┘

After the changes in this patch, this becomes

            ┌─────────┐             
            │  Union  │             
            └─────────┘             
                 ▲                  
        ┌────────┴──────────┐       
        │                   │       
┌───────────────┐   ┌──────────────┐
│GlobalLimit 10 │   │GlobalLimit 10│
└───────────────┘   └──────────────┘
        ▲                   ▲       
        │                   │       
┌──────────────┐    ┌──────────────┐
│LocalLimit 10 │    │LocalLimit 10 │
└──────────────┘    └──────────────┘
        ▲                   ▲       
        │                   │       
  ┌──────────┐        ┌──────────┐  
  │  Scan A  │        │  Scan B  │  
  └──────────┘        └──────────┘  

What is not legal to do here is to pull the GlobalLimit up, so the following would be wrong:

         ┌───────────────┐          
         │GlobalLimit 10 │          
         └───────────────┘          
                 ▲                  
                 │                  
            ┌─────────┐             
            │  Union  │             
            └─────────┘             
                 ▲                  
        ┌────────┴──────────┐       
        │                   │       
┌──────────────┐    ┌──────────────┐
│LocalLimit 10 │    │LocalLimit 10 │
└──────────────┘    └──────────────┘
        ▲                   ▲       
        │                   │       
  ┌──────────┐        ┌──────────┐  
  │  Scan A  │        │  Scan B  │  
  └──────────┘        └──────────┘  

That plan would be semantically equivalent to executing

  select * from (select * from A ) t1 UNION ALL (select * from B) t2 LIMIT 10

@davies, were you suggesting that the current planning is wrong? Or that we need more tests to guard against incorrect changes to limit planning? I don't believe that the changes in this patch will affect the planning of the case being described here, since we're not making any changes to limit pull-up or push-down. I do have a followup patch in the works which takes @gatorsmile's two limit-pushdown patches and rebases them on top of the changes here: master...JoshRosen:limit-pushdown-2. In that patch, I do plan to add more tests to handle these pushdown-related concerns.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JoshRosen Thank you for your explanation! That is so great that my previous PR are useful. So far, I am unable to find more operators for limit push down, except outer join and union all: #10454 and #10451

val globalLimit = execution.GlobalLimit(limit, perPartitionLimit)
globalLimit :: Nil
case logical.Union(unionChildren) =>
execution.Union(unionChildren.map(planLater)) :: Nil
case logical.Except(left, right) =>
@@ -357,6 +361,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case logical.ReturnAnswer(child) => planLater(child) :: Nil
case _ => Nil
}
}
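
The net effect on planning, illustrated with DataFrame calls (the table and column names are hypothetical, and the plan shapes paraphrase the strategy above):

  // Hypothetical names, for illustration only.
  val df = sqlContext.table("events")

  // Terminal limit: the planner is handed ReturnAnswer(Limit(10, ...)), so the
  // CollectLimit case above fires and only a driver-side collect of 10 rows is planned.
  df.limit(10).collect()

  // Non-terminal limit: the Limit sits below other operators, so it is planned as
  // GlobalLimit(10, LocalLimit(10, ...)) -- a per-partition limit followed by a global one.
  df.limit(10).groupBy("user").count().collect()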
@@ -17,16 +17,13 @@

package org.apache.spark.sql.execution

import org.apache.spark.{HashPartitioner, SparkEnv}
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
import org.apache.spark.util.MutablePair
import org.apache.spark.util.random.PoissonSampler

case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
@@ -299,96 +296,6 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan {
sparkContext.union(children.map(_.execute()))
}

/**
* Take the first limit elements. Note that the implementation is different depending on whether
* this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
* this operator uses something similar to Spark's take method on the Spark driver. If it is not
* terminal or is invoked using execute, we first take the limit on each partition, and then
* repartition all the data to a single partition to compute the global limit.
*/
case class Limit(limit: Int, child: SparkPlan)
extends UnaryNode {
// TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
// partition local limit -> exchange into one partition -> partition local limit again

/** We must copy rows when sort based shuffle is on */
private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]

override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = SinglePartition

override def executeCollect(): Array[InternalRow] = child.executeTake(limit)

protected override def doExecute(): RDD[InternalRow] = {
val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) {
child.execute().mapPartitionsInternal { iter =>
iter.take(limit).map(row => (false, row.copy()))
}
} else {
child.execute().mapPartitionsInternal { iter =>
val mutablePair = new MutablePair[Boolean, InternalRow]()
iter.take(limit).map(row => mutablePair.update(false, row))
}
}
val part = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
shuffled.mapPartitionsInternal(_.take(limit).map(_._2))
}
}

/**
* Take the first limit elements as defined by the sortOrder, and do projection if needed.
* This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
* or having a [[Project]] operator between them.
* This could have been named TopK, but Spark's top operator does the opposite in ordering
* so we name it TakeOrdered to avoid confusion.
*/
case class TakeOrderedAndProject(
limit: Int,
sortOrder: Seq[SortOrder],
projectList: Option[Seq[NamedExpression]],
child: SparkPlan) extends UnaryNode {

override def output: Seq[Attribute] = {
val projectOutput = projectList.map(_.map(_.toAttribute))
projectOutput.getOrElse(child.output)
}

override def outputPartitioning: Partitioning = SinglePartition

// We need to use an interpreted ordering here because generated orderings cannot be serialized
// and this ordering needs to be created on the driver in order to be passed into Spark core code.
private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)

private def collectData(): Array[InternalRow] = {
val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
if (projectList.isDefined) {
val proj = UnsafeProjection.create(projectList.get, child.output)
data.map(r => proj(r).copy())
} else {
data
}
}

override def executeCollect(): Array[InternalRow] = {
collectData()
}

// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)

override def outputOrdering: Seq[SortOrder] = sortOrder

override def simpleString: String = {
val orderByString = sortOrder.mkString("[", ",", "]")
val outputString = output.mkString("[", ",", "]")

s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
}
}
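
For context, the query shape this operator serves, illustrated with hypothetical names (a sketch, not code from this diff):

  // Hypothetical table/column names, for illustration only.
  val df = sqlContext.table("logs")

  // An ORDER BY followed by a LIMIT is planned into TakeOrderedAndProject; its
  // collectData() above uses RDD.takeOrdered(limit), so at most `limit` rows per
  // partition are kept before the final merge on the driver.
  df.orderBy(df("ts").desc).limit(10).collect()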

/**
* Return a new RDD that has exactly `numPartitions` partitions.
* Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.