diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index abd38f2f9d940..6b06cf13262d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1822,7 +1822,7 @@ class Analyzer(override val catalogManager: CatalogManager) val newOrders = orders map { case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => if (index > 0 && index <= child.output.size) { - SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty) + SortOrder(child.output(index - 1), direction, nullOrdering, Seq.empty) } else { s.failAnalysis( s"ORDER BY position $index is not in select list " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 89cf97e76d798..5303aacef4271 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -131,9 +131,9 @@ package object dsl { } def asc: SortOrder = SortOrder(expr, Ascending) - def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty) + def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Seq.empty) def desc: SortOrder = SortOrder(expr, Descending) - def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty) + def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Seq.empty) def as(alias: String): NamedExpression = Alias(expr, alias)() def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 54259e713accd..d9923b5d022e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -63,8 +63,10 @@ case class SortOrder( child: Expression, direction: SortDirection, nullOrdering: NullOrdering, - sameOrderExpressions: Set[Expression]) - extends UnaryExpression with Unevaluable { + sameOrderExpressions: Seq[Expression]) + extends Expression with Unevaluable { + + override def children: Seq[Expression] = child +: sameOrderExpressions override def checkInputDataTypes(): TypeCheckResult = { if (RowOrdering.isOrderable(dataType)) { @@ -83,7 +85,7 @@ case class SortOrder( def isAscending: Boolean = direction == Ascending def satisfies(required: SortOrder): Boolean = { - (sameOrderExpressions + child).exists(required.child.semanticEquals) && + children.exists(required.child.semanticEquals) && direction == required.direction && nullOrdering == required.nullOrdering } } @@ -92,7 +94,7 @@ object SortOrder { def apply( child: Expression, direction: SortDirection, - sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = { + sameOrderExpressions: Seq[Expression] = Seq.empty): SortOrder = { new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ff8b56f0b724b..c4be7e86fa5e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1905,7 +1905,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } else { direction.defaultNullOrdering } - SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty) + SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b3e403ffa7382..95134d9111593 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1228,7 +1228,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) } + def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Seq.empty) } /** * Returns a sort expression based on the descending order of the column, @@ -1244,7 +1244,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) } + def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Seq.empty) } /** * Returns a sort expression based on ascending order of the column. @@ -1275,7 +1275,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) } + def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Seq.empty) } /** * Returns a sort expression based on ascending order of the column, @@ -1291,7 +1291,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) } + def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Seq.empty) } /** * Prints the expression to the console for debugging purposes. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala index 3ba8745be995f..3cbe1654ea2cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala @@ -65,11 +65,7 @@ trait AliasAwareOutputOrdering extends AliasAwareOutputExpression { final override def outputOrdering: Seq[SortOrder] = { if (hasAlias) { - orderingExpressions.map { sortOrder => - val newSortOrder = normalizeExpression(sortOrder).asInstanceOf[SortOrder] - val newSameOrderExpressions = newSortOrder.sameOrderExpressions.map(normalizeExpression) - newSortOrder.copy(sameOrderExpressions = newSameOrderExpressions) - } + orderingExpressions.map(normalizeExpression(_).asInstanceOf[SortOrder]) } else { orderingExpressions } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 6e59ad07d7168..eabbdc8ed3243 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -68,9 +68,9 @@ case class SortMergeJoinExec( val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => - // Also add the right key and its `sameOrderExpressions` - SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey - .sameOrderExpressions) + // Also add expressions from right side sort order + val sameOrderExpressions = ExpressionSet(lKey.sameOrderExpressions ++ rKey.children) + SortOrder(lKey.child, Ascending, sameOrderExpressions.toSeq) } // For left and right outer joins, the output is ordered by the streamed input's join keys. case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering) @@ -96,7 +96,8 @@ case class SortMergeJoinExec( val requiredOrdering = requiredOrders(keys) if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) { keys.zip(childOutputOrdering).map { case (key, childOrder) => - SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key) + val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key + SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq) } } else { requiredOrdering diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 6de81cc414d7d..5e30f846307ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -1090,6 +1090,32 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { } } + test("sort order doesn't have repeated expressions") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withTempView("t1", "t2") { + spark.range(10).repartition($"id").createTempView("t1") + spark.range(20).repartition($"id").createTempView("t2") + val planned = sql( + """ + | SELECT t12.id, t1.id + | FROM (SELECT t1.id FROM t1, t2 WHERE t1.id * 2 = t2.id) t12, t1 + | where 2 * t12.id = t1.id + """.stripMargin).queryExecution.executedPlan + + // t12 is already sorted on `t1.id * 2`. and we need to sort it on `2 * t12.id` + // for 2nd join. So sorting on t12 can be avoided + val sortNodes = planned.collect { case s: SortExec => s } + assert(sortNodes.size == 3) + val outputOrdering = planned.outputOrdering + assert(outputOrdering.size == 1) + // Sort order should have 3 childrens, not 4. This is because t1.id*2 and 2*t1.id are same + assert(outputOrdering.head.children.size == 3) + assert(outputOrdering.head.children.count(_.isInstanceOf[AttributeReference]) == 2) + assert(outputOrdering.head.children.count(_.isInstanceOf[Multiply]) == 1) + } + } + } + test("aliases to expressions should not be replaced") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { withTempView("df1", "df2") {