diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2cc27d82f7d2..d2d8ce432aa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -164,10 +164,20 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) * Optimize all the subqueries inside expression. */ object OptimizeSubqueries extends Rule[LogicalPlan] { + private def removeTopLevelSort(plan: LogicalPlan): LogicalPlan = { + plan match { + case Sort(_, _, child) => child + case Project(fields, child) => Project(fields, removeTopLevelSort(child)) + case other => other + } + } def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s: SubqueryExpression => val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan)) - s.withNewPlan(newPlan) + // At this point we have an optimized subquery plan that we are going to attach + // to this subquery expression. Here we can safely remove any top-level sort + // in the plan as tuples produced by a subquery are unordered. 
+ s.withNewPlan(removeTopLevelSort(newPlan)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index acef62d81ee1..cbffed994bb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.Join +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { @@ -970,4 +973,299 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row("3", "b") :: Row("4", "b") :: Nil) } } + + private def getNumSortsInQuery(query: String): Int = { + val plan = sql(query).queryExecution.optimizedPlan + getNumSorts(plan) + getSubqueryExpressions(plan).map{s => getNumSorts(s.plan)}.sum + } + + private def getSubqueryExpressions(plan: LogicalPlan): Seq[SubqueryExpression] = { + val subqueryExpressions = ArrayBuffer.empty[SubqueryExpression] + plan transformAllExpressions { + case s: SubqueryExpression => + subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s) + s + } + subqueryExpressions + } + + private def getNumSorts(plan: LogicalPlan): Int = { + plan.collect { case s: Sort => s }.size + } + + test("SPARK-23957 Remove redundant sort from subquery plan(in subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Simple ORDER BY + val query1 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 ORDER BY c1) + """.stripMargin 
+ assert(getNumSortsInQuery(query1) == 0) + + // Nested ORDER BYs + val query2 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM (SELECT * + | FROM t2 + | ORDER BY c2) + | ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + + // Nested IN + val query3 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM t2 + | WHERE c1 IN (SELECT c1 + | FROM t3 + | WHERE c1 = 1 + | ORDER BY c3) + | ORDER BY c2) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Complex subplan and multiple sorts + val query4 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM (SELECT c1, c2, count(*) + | FROM t2 + | GROUP BY c1, c2 + | HAVING count(*) > 0 + | ORDER BY c2) + | ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query4) == 0) + + // Join in subplan + val query5 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT t2.c1 FROM t2, t3 + | WHERE t2.c1 = t3.c1 + | ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query5) == 0) + + val query6 = + """ + |SELECT c1 + |FROM t1 + |WHERE (c1, c2) IN (SELECT c1, max(c2) + | FROM (SELECT c1, c2, count(*) + | FROM t2 + | GROUP BY c1, c2 + | HAVING count(*) > 0 + | ORDER BY c2) + | GROUP BY c1 + | HAVING max(c2) > 0 + | ORDER BY c1) + """.stripMargin + // The rule to remove redundant sorts is not able to remove the inner sort under + // an Aggregate operator. We only remove the top-level sort. 
+ assert(getNumSortsInQuery(query6) == 1) + + // Cases when sort is not removed from the plan + // Limit on top of sort + val query7 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1) + """.stripMargin + assert(getNumSortsInQuery(query7) == 1) + + // Sort below a set operation (intersect, union) + val query8 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (( + | SELECT c1 FROM t2 + | ORDER BY c1 + | ) + | UNION + | ( + | SELECT c1 FROM t2 + | ORDER BY c1 + | )) + """.stripMargin + assert(getNumSortsInQuery(query8) == 2) + } + } + + test("SPARK-23957 Remove redundant sort from subquery plan(exists subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Simple ORDER BY in a correlated EXISTS + val query1 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // Nested ORDER BYs in a correlated EXISTS. 
+ val query2 = + """ + |SELECT c1 + |FROM t1 + |WHERE EXISTS (SELECT c1 + | FROM (SELECT * + | FROM t2 + | WHERE t2.c1 = t1.c1 + | ORDER BY t2.c2) t2 + | ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + // Nested EXISTS + val query3 = + """ + |SELECT c1 + |FROM t1 + |WHERE EXISTS (SELECT c1 + | FROM t2 + | WHERE EXISTS (SELECT c1 + | FROM t3 + | WHERE t3.c1 = t2.c1 + | ORDER BY c3) + | AND t2.c1 = t1.c1 + | ORDER BY c2) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Cases when sort is not removed from the plan + // Limit on top of sort + val query4 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (SELECT t2.c1 FROM t2 WHERE t2.c1 = 1 ORDER BY t2.c1 limit 1) + """.stripMargin + assert(getNumSortsInQuery(query4) == 1) + + // Sort below a set operation (intersect, union) + val query5 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (( + | SELECT c1 FROM t2 + | WHERE t2.c1 = 1 + | ORDER BY t2.c1 + | ) + | UNION + | ( + | SELECT c1 FROM t2 + | WHERE t2.c1 = 2 + | ORDER BY t2.c1 + | )) + """.stripMargin + assert(getNumSortsInQuery(query5) == 2) + } + } + + test("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Two scalar subqueries in OR + val query1 = + """ + |SELECT * FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | ORDER BY max(t2.c1)) + |OR c2 = (SELECT min(t3.c2) + | FROM t3 + | WHERE t3.c1 = 1 + | ORDER BY min(t3.c2)) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // Scalar subquery - GROUP BY and HAVING + val query2 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1)) + """.stripMargin + assert(getNumSortsInQuery(query2) 
== 0) + + // Nested scalar subquery + val query3 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | WHERE c1 = (SELECT max(t3.c1) + | FROM t3 + | WHERE t3.c1 = 1 + | GROUP BY t3.c1 + | ORDER BY max(t3.c1) + | ) + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1)) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Scalar subquery in projection + val query4 = + """ + |SELECT (SELECT min(c1) from t1 group by c1 order by c1) + |FROM t1 + |WHERE t1.c1 = 1 + """.stripMargin + assert(getNumSortsInQuery(query4) == 0) + + // Limit on top of sort prevents it from being pruned. + val query5 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | WHERE c1 = (SELECT max(t3.c1) + | FROM t3 + | WHERE t3.c1 = 1 + | GROUP BY t3.c1 + | ORDER BY max(t3.c1) + | ) + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1) + | LIMIT 1) + """.stripMargin + assert(getNumSortsInQuery(query5) == 1) + } + } }