Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,20 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
* Optimizes all the subqueries nested inside expressions.
*/
object OptimizeSubqueries extends Rule[LogicalPlan] {
  /**
   * Strips an ordering that sits at the top of a subquery plan.
   *
   * A `Sort` at the root is replaced by its child. A `Project` at the root is kept and the
   * search continues into its child. Any other operator ends the search and is returned
   * unchanged, so a sort buried beneath it is left in place.
   *
   * @param plan the (already optimized) subquery plan
   * @return the plan without its top-level sort, or `plan` itself when there is none
   */
  private def removeTopLevelSort(plan: LogicalPlan): LogicalPlan = plan match {
    case Sort(_, _, child) => child
    case Project(fields, child) => Project(fields, removeTopLevelSort(child))
    case other => other
  }

  /**
   * Runs the enclosing optimizer over the plan of every [[SubqueryExpression]] found in
   * `plan` and re-attaches the optimized plan, minus any top-level sort, to the expression.
   */
  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    case s: SubqueryExpression =>
      // Wrap the plan in a Subquery node for optimization and unwrap the result again.
      val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan))
      // At this point we have an optimized subquery plan that we are going to attach
      // to this subquery expression. Here we can safely remove any top level sort
      // in the plan as tuples produced by a subquery are un-ordered.
      s.withNewPlan(removeTopLevelSort(newPlan))
  }
}

Expand Down
300 changes: 299 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.plans.logical.Join
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
import org.apache.spark.sql.test.SharedSQLContext

class SubquerySuite extends QueryTest with SharedSQLContext {
Expand Down Expand Up @@ -970,4 +973,299 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
Row("3", "b") :: Row("4", "b") :: Nil)
}
}

/**
 * Counts the `Sort` operators in the optimized plan of `query`, plus those found in the
 * plans of the subquery expressions gathered from that plan.
 */
private def getNumSortsInQuery(query: String): Int = {
  val optimized = sql(query).queryExecution.optimizedPlan
  val topLevelSorts = getNumSorts(optimized)
  val subquerySorts = getSubqueryExpressions(optimized).map(expr => getNumSorts(expr.plan)).sum
  topLevelSorts + subquerySorts
}

/**
 * Gathers every [[SubqueryExpression]] reachable from `plan`, recursing into the plan of
 * each subquery expression that is encountered.
 */
private def getSubqueryExpressions(plan: LogicalPlan): Seq[SubqueryExpression] = {
  val collected = ArrayBuffer.empty[SubqueryExpression]
  // The traversal is used purely for its visiting order: every match is recorded and
  // handed back unchanged, and the transformed plan itself is discarded.
  plan.transformAllExpressions {
    case subquery: SubqueryExpression =>
      // Nested subqueries are appended before the expression that contains them.
      collected ++= getSubqueryExpressions(subquery.plan)
      collected += subquery
      subquery
  }
  collected
}

/** Returns how many `Sort` operators `plan.collect` finds in `plan`. */
private def getNumSorts(plan: LogicalPlan): Int =
  plan.collect { case _: Sort => 1 }.sum

// Sort removal for IN subqueries: every query below places ORDER BY clauses inside an
// IN subquery and asserts how many Sort operators getNumSortsInQuery still finds in the
// optimized plan and in the plans of its subquery expressions.
test("SPARK-23957 Remove redundant sort from subquery plan(in subquery)") {
withTempView("t1", "t2", "t3") {
// Register the three small temp views that the queries below read from.
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3")

// Simple ORDER BY at the top of the subquery
val query1 =
"""
|SELECT c1 FROM t1
|WHERE
|c1 IN (SELECT c1 FROM t2 ORDER BY c1)
""".stripMargin
assert(getNumSortsInQuery(query1) == 0)

// Nested ORDER BYs: one in a derived table, one at the top of the subquery
val query2 =
"""
|SELECT c1
|FROM t1
|WHERE c1 IN (SELECT c1
|             FROM (SELECT *
|                   FROM t2
|                   ORDER BY c2)
|             ORDER BY c1)
""".stripMargin
assert(getNumSortsInQuery(query2) == 0)


// Nested IN subqueries, each with its own ORDER BY
val query3 =
"""
|SELECT c1
|FROM t1
|WHERE c1 IN (SELECT c1
|             FROM t2
|             WHERE c1 IN (SELECT c1
|                          FROM t3
|                          WHERE c1 = 1
|                          ORDER BY c3)
|             ORDER BY c2)
""".stripMargin
assert(getNumSortsInQuery(query3) == 0)

// Complex subplan and multiple sorts
val query4 =
"""
|SELECT c1
|FROM t1
|WHERE c1 IN (SELECT c1
|             FROM (SELECT c1, c2, count(*)
|                   FROM t2
|                   GROUP BY c1, c2
|                   HAVING count(*) > 0
|                   ORDER BY c2)
|             ORDER BY c1)
""".stripMargin
assert(getNumSortsInQuery(query4) == 0)

// Join in subplan
val query5 =
"""
|SELECT c1 FROM t1
|WHERE
|c1 IN (SELECT t2.c1 FROM t2, t3
|       WHERE t2.c1 = t3.c1
|       ORDER BY t2.c1)
""".stripMargin
assert(getNumSortsInQuery(query5) == 0)

// Multi-column IN whose subquery aggregates over a sorted, aggregated derived table
val query6 =
"""
|SELECT c1
|FROM t1
|WHERE (c1, c2) IN (SELECT c1, max(c2)
|                   FROM (SELECT c1, c2, count(*)
|                         FROM t2
|                         GROUP BY c1, c2
|                         HAVING count(*) > 0
|                         ORDER BY c2)
|                   GROUP BY c1
|                   HAVING max(c2) > 0
|                   ORDER BY c1)
""".stripMargin
// The rule to remove redundant sorts is not able to remove the inner sort under
// an Aggregate operator. We only remove the top level sort.
assert(getNumSortsInQuery(query6) == 1)

// Cases where the sort is not removed from the plan
// Limit on top of sort
val query7 =
"""
|SELECT c1 FROM t1
|WHERE
|c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1)
""".stripMargin
assert(getNumSortsInQuery(query7) == 1)

// Sorts below a set operation (UNION here): both are expected to remain
val query8 =
"""
|SELECT c1 FROM t1
|WHERE
|c1 IN ((
|        SELECT c1 FROM t2
|        ORDER BY c1
|       )
|       UNION
|       (
|         SELECT c1 FROM t2
|         ORDER BY c1
|       ))
""".stripMargin
assert(getNumSortsInQuery(query8) == 2)
}
}

// Sort removal for EXISTS subqueries: every query below places ORDER BY clauses inside an
// EXISTS subquery and asserts how many Sort operators getNumSortsInQuery still finds in
// the optimized plan and in the plans of its subquery expressions.
test("SPARK-23957 Remove redundant sort from subquery plan(exists subquery)") {
withTempView("t1", "t2", "t3") {
// Register the three small temp views that the queries below read from.
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3")

// Simple ORDER BY in a correlated EXISTS subquery
val query1 =
"""
|SELECT c1 FROM t1
|WHERE
|EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY t2.c1)
""".stripMargin
assert(getNumSortsInQuery(query1) == 0)

// Nested ORDER BYs with the correlation inside the derived table
val query2 =
"""
|SELECT c1
|FROM t1
|WHERE  EXISTS (SELECT c1
|               FROM (SELECT *
|                     FROM t2
|                     WHERE t2.c1 = t1.c1
|                     ORDER BY t2.c2) t2
|               ORDER BY t2.c1)
""".stripMargin
assert(getNumSortsInQuery(query2) == 0)

// Nested EXISTS subqueries, each with its own ORDER BY
val query3 =
"""
|SELECT c1
|FROM t1
|WHERE EXISTS (SELECT c1
|              FROM t2
|              WHERE EXISTS (SELECT c1
|                            FROM t3
|                            WHERE t3.c1 = t2.c1
|                            ORDER BY c3)
|              AND t2.c1 = t1.c1
|              ORDER BY c2)
""".stripMargin
assert(getNumSortsInQuery(query3) == 0)

// Cases where the sort is not removed from the plan
// Limit on top of sort
val query4 =
"""
|SELECT c1 FROM t1
|WHERE
|EXISTS (SELECT t2.c1 FROM t2 WHERE t2.c1 = 1 ORDER BY t2.c1 limit 1)
""".stripMargin
assert(getNumSortsInQuery(query4) == 1)

// Sorts below a set operation (UNION here): both are expected to remain
val query5 =
"""
|SELECT c1 FROM t1
|WHERE
|EXISTS ((
|        SELECT c1 FROM t2
|        WHERE t2.c1 = 1
|        ORDER BY t2.c1
|        )
|        UNION
|        (
|         SELECT c1 FROM t2
|         WHERE t2.c1 = 2
|         ORDER BY t2.c1
|        ))
""".stripMargin
assert(getNumSortsInQuery(query5) == 2)
}
}

// Sort removal for scalar subqueries: every query below places ORDER BY clauses inside a
// scalar subquery and asserts how many Sort operators getNumSortsInQuery still finds in
// the optimized plan and in the plans of its subquery expressions.
test("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") {
withTempView("t1", "t2", "t3") {
// Register the three small temp views that the queries below read from.
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3")

// Two scalar subqueries combined with OR
val query1 =
"""
|SELECT * FROM t1
|WHERE c1 = (SELECT max(t2.c1)
|            FROM t2
|            ORDER BY max(t2.c1))
|OR    c2 = (SELECT min(t3.c2)
|            FROM t3
|            WHERE t3.c1 = 1
|            ORDER BY min(t3.c2))
""".stripMargin
assert(getNumSortsInQuery(query1) == 0)

// Scalar subquery with GROUP BY and HAVING
val query2 =
"""
|SELECT *
|FROM t1
|WHERE c1 = (SELECT max(t2.c1)
|            FROM t2
|            GROUP BY t2.c1
|            HAVING count(*) >= 1
|            ORDER BY max(t2.c1))
""".stripMargin
assert(getNumSortsInQuery(query2) == 0)

// Nested scalar subqueries, each with its own ORDER BY
val query3 =
"""
|SELECT *
|FROM t1
|WHERE c1 = (SELECT max(t2.c1)
|            FROM t2
|            WHERE c1 = (SELECT max(t3.c1)
|                        FROM t3
|                        WHERE t3.c1 = 1
|                        GROUP BY t3.c1
|                        ORDER BY max(t3.c1)
|                       )
|            GROUP BY t2.c1
|            HAVING count(*) >= 1
|            ORDER BY max(t2.c1))
""".stripMargin
assert(getNumSortsInQuery(query3) == 0)

// Scalar subquery in projection
val query4 =
"""
|SELECT (SELECT min(c1) from t1 group by c1 order by c1)
|FROM t1
|WHERE t1.c1 = 1
""".stripMargin
assert(getNumSortsInQuery(query4) == 0)

// Limit on top of sort prevents it from being pruned.
// Only the outer scalar subquery carries the LIMIT; the expected count is one sort.
val query5 =
"""
|SELECT *
|FROM t1
|WHERE c1 = (SELECT max(t2.c1)
|            FROM t2
|            WHERE c1 = (SELECT max(t3.c1)
|                        FROM t3
|                        WHERE t3.c1 = 1
|                        GROUP BY t3.c1
|                        ORDER BY max(t3.c1)
|                       )
|            GROUP BY t2.c1
|            HAVING count(*) >= 1
|            ORDER BY max(t2.c1)
|            LIMIT 1)
""".stripMargin
assert(getNumSortsInQuery(query5) == 1)
}
}
}