project/MimaExcludes.scala (4 changes: 3 additions & 1 deletion)
@@ -46,7 +46,9 @@ object MimaExcludes {
           "org.apache.spark.api.java.JavaRDDLike.partitioner"),
         // Mima false positive (was a private[spark] class)
         ProblemFilters.exclude[MissingClassProblem](
-          "org.apache.spark.util.collection.PairIterator")
+          "org.apache.spark.util.collection.PairIterator"),
+        // SQL execution is considered private.
+        excludePackage("org.apache.spark.sql.execution")
       )
     case v if v.startsWith("1.4") =>
       Seq(
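Note: excludePackage is a helper defined in Spark's MiMa build files (its definition is not part of this diff). It expands to problem filters covering every class under the named package; the exact expansion is an assumption here, but it is roughly a wildcard exclude like:

    // Sketch of what the helper covers (MiMa wildcard problem filter):
    ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*")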
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -36,6 +36,8 @@ object DefaultOptimizer extends Optimizer {
     // SubQueries are only needed for analysis and can be removed before execution.
     Batch("Remove SubQueries", FixedPoint(100),
       EliminateSubQueries) ::
+    Batch("Distinct", FixedPoint(100),
Review comment (Contributor): Minor: it seems that Once is enough here. The same applies to the "Remove SubQueries" batch above.
+      ReplaceDistinctWithAggregate) ::
     Batch("Operator Reordering", FixedPoint(100),
       UnionPushdown,
       CombineFilters,
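Note: the reviewer's suggestion above would look like the sketch below (not applied in this diff). A single pass suffices because the rewrite can never re-introduce a Distinct node for a later iteration to match:

    Batch("Distinct", Once,
      ReplaceDistinctWithAggregate) ::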
@@ -696,3 +698,15 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
       LocalRelation(projectList.map(_.toAttribute), data.map(projection))
   }
 }

+/**
+ * Replaces the logical [[Distinct]] operator with an [[Aggregate]] operator.
+ * {{{
+ *   SELECT DISTINCT f1, f2 FROM t  ==>  SELECT f1, f2 FROM t GROUP BY f1, f2
+ * }}}
+ */
+object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Distinct(child) => Aggregate(child.output, child.output, child)
Review comment (Contributor): An example in the comment would be useful for understanding:

    SELECT DISTINCT f1, f2 FROM t  ==>  SELECT f1, f2 FROM t GROUP BY f1, f2
+  }
+}
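Note: the rewrite preserves semantics because grouping on every output column, with those same columns as the aggregate expressions, emits exactly one row per distinct tuple. A minimal sketch of invoking the rule directly, using catalyst's test DSL (the same helpers the new suite later in this diff relies on):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation}

    val t = LocalRelation('f1.int, 'f2.int)
    // One row per distinct (f1, f2); output attributes are unchanged.
    assert(ReplaceDistinctWithAggregate(Distinct(t)) == Aggregate(t.output, t.output, t))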
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -339,6 +339,9 @@ case class Sample(
   override def output: Seq[Attribute] = child.output
 }

+/**
+ * Returns a new logical plan that dedups input rows.
+ */
 case class Distinct(child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 }
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala (new file)
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class ReplaceDistinctWithAggregateSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil
+  }
+
+  test("replace distinct with aggregate") {
+    val input = LocalRelation('a.int, 'b.int)
+
+    val query = Distinct(input)
+    val optimized = Optimize.execute(query.analyze)
+
+    val correctAnswer = Aggregate(input.output, input.output, input)
+
+    comparePlans(optimized, correctAnswer)
+  }
+}
sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1295,7 +1295,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  override def distinct: DataFrame = Distinct(logicalPlan)
+  override def distinct: DataFrame = dropDuplicates()

   /**
    * @group basic
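Note: dropDuplicates() with no arguments deduplicates over all columns, so distinct keeps its documented behavior while now producing the same Aggregate-based plan. A usage-level sketch of the equivalence (assumes a Spark 1.4 SQLContext named sqlContext):

    val df = sqlContext.range(0, 10).selectExpr("id % 3 AS k")
    df.distinct.explain()         // plans an Aggregate grouped on all columns
    df.dropDuplicates().explain() // identical plan: distinct delegates here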
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -284,8 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case r: RunnableCommand => ExecutedCommand(r) :: Nil

       case logical.Distinct(child) =>
-        execution.Distinct(partial = false,
-          execution.Distinct(partial = true, planLater(child))) :: Nil
+        throw new IllegalStateException(
+          "logical distinct operator should have been replaced by aggregate in the optimizer")
       case logical.Repartition(numPartitions, shuffle, child) =>
         execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil
       case logical.SortPartitions(sortExprs, child) =>
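Note: in the normal pipeline the optimizer batch added earlier rewrites every logical Distinct before planning, so this branch only fires if a plan bypassed optimization; failing fast makes that bug visible instead of silently planning a removed operator. A hedged sanity check using the catalyst test DSL:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
    import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation}

    val optimized = DefaultOptimizer.execute(Distinct(LocalRelation('a.int)).analyze)
    assert(optimized.collect { case d: Distinct => d }.isEmpty) // planner never sees Distinct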
sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -230,37 +230,6 @@ case class ExternalSort(
   override def outputOrdering: Seq[SortOrder] = sortOrder
 }

-/**
- * :: DeveloperApi ::
- * Computes the set of distinct input rows using a HashSet.
- * @param partial when true the distinct operation is performed partially, per partition, without
- *                shuffling the data.
- * @param child the input query plan.
- */
-@DeveloperApi
-case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode {
-  override def output: Seq[Attribute] = child.output
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil
-
-  protected override def doExecute(): RDD[Row] = {
-    child.execute().mapPartitions { iter =>
-      val hashSet = new scala.collection.mutable.HashSet[Row]()
-
-      var currentRow: Row = null
-      while (iter.hasNext) {
-        currentRow = iter.next()
-        if (!hashSet.contains(currentRow)) {
-          hashSet.add(currentRow.copy())
-        }
-      }
-
-      hashSet.iterator
-    }
-  }
-}

 /**
  * :: DeveloperApi ::
  * Return a new RDD that has exactly `numPartitions` partitions.