Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -1822,7 +1822,7 @@ class Analyzer(override val catalogManager: CatalogManager)
val newOrders = orders map {
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
if (index > 0 && index <= child.output.size) {
SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty)
SortOrder(child.output(index - 1), direction, nullOrdering, Seq.empty)
} else {
s.failAnalysis(
s"ORDER BY position $index is not in select list " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ package object dsl {
}

def asc: SortOrder = SortOrder(expr, Ascending)
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Seq.empty)
def desc: SortOrder = SortOrder(expr, Descending)
def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty)
def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Seq.empty)
def as(alias: String): NamedExpression = Alias(expr, alias)()
def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ case class SortOrder(
child: Expression,
direction: SortDirection,
nullOrdering: NullOrdering,
sameOrderExpressions: Set[Expression])
extends UnaryExpression with Unevaluable {
sameOrderExpressions: Seq[Expression])
extends Expression with Unevaluable {

override def children: Seq[Expression] = child +: sameOrderExpressions

override def checkInputDataTypes(): TypeCheckResult = {
if (RowOrdering.isOrderable(dataType)) {
Expand All @@ -83,7 +85,7 @@ case class SortOrder(
def isAscending: Boolean = direction == Ascending

def satisfies(required: SortOrder): Boolean = {
(sameOrderExpressions + child).exists(required.child.semanticEquals) &&
children.exists(required.child.semanticEquals) &&
direction == required.direction && nullOrdering == required.nullOrdering
}
}
Expand All @@ -92,7 +94,7 @@ object SortOrder {
def apply(
child: Expression,
direction: SortDirection,
sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = {
sameOrderExpressions: Seq[Expression] = Seq.empty): SortOrder = {
new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
} else {
direction.defaultNullOrdering
}
SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty)
SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty)
}

/**
Expand Down
8 changes: 4 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,7 @@ class Column(val expr: Expression) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) }
def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Seq.empty) }

/**
* Returns a sort expression based on the descending order of the column,
Expand All @@ -1244,7 +1244,7 @@ class Column(val expr: Expression) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) }
def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Seq.empty) }

/**
* Returns a sort expression based on ascending order of the column.
Expand Down Expand Up @@ -1275,7 +1275,7 @@ class Column(val expr: Expression) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) }
def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Seq.empty) }

/**
* Returns a sort expression based on ascending order of the column,
Expand All @@ -1291,7 +1291,7 @@ class Column(val expr: Expression) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) }
def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Seq.empty) }

/**
* Prints the expression to the console for debugging purposes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,7 @@ trait AliasAwareOutputOrdering extends AliasAwareOutputExpression {

final override def outputOrdering: Seq[SortOrder] = {
if (hasAlias) {
orderingExpressions.map { sortOrder =>
val newSortOrder = normalizeExpression(sortOrder).asInstanceOf[SortOrder]
val newSameOrderExpressions = newSortOrder.sameOrderExpressions.map(normalizeExpression)
newSortOrder.copy(sameOrderExpressions = newSameOrderExpressions)
}
orderingExpressions.map(normalizeExpression(_).asInstanceOf[SortOrder])
} else {
orderingExpressions
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ case class SortMergeJoinExec(
val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering)
val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering)
leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) =>
// Also add the right key and its `sameOrderExpressions`
SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey
.sameOrderExpressions)
// Also add expressions from right side sort order
val sameOrderExpressions = ExpressionSet(lKey.sameOrderExpressions ++ rKey.children)
SortOrder(lKey.child, Ascending, sameOrderExpressions.toSeq)
}
// For left and right outer joins, the output is ordered by the streamed input's join keys.
case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering)
Expand All @@ -96,7 +96,8 @@ case class SortMergeJoinExec(
val requiredOrdering = requiredOrders(keys)
if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) {
keys.zip(childOutputOrdering).map { case (key, childOrder) =>
SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key)
val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key
SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq)
}
} else {
requiredOrdering
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,32 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}
}

// Checks that, for a query chaining two joins on `id * 2` / `2 * id` style keys, the executed
// plan's output ordering does not list the same expression twice among the SortOrder's children.
test("sort order doesn't have repeated expressions") {
// NOTE(review): a threshold of -1 presumably disables broadcast joins so that the joins below
// are planned as sort-based joins -- confirm against the SQLConf documentation.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("t1", "t2") {
// Two temp views over a single `id` column, each repartitioned by `id`.
spark.range(10).repartition($"id").createTempView("t1")
spark.range(20).repartition($"id").createTempView("t2")
val planned = sql(
"""
| SELECT t12.id, t1.id
| FROM (SELECT t1.id FROM t1, t2 WHERE t1.id * 2 = t2.id) t12, t1
| where 2 * t12.id = t1.id
""".stripMargin).queryExecution.executedPlan

// t12 is already sorted on `t1.id * 2`, and the 2nd join needs it sorted on `2 * t12.id`,
// so the extra sort on t12 can be avoided.
// Collect every SortExec in the executed plan; exactly three are expected.
val sortNodes = planned.collect { case s: SortExec => s }
assert(sortNodes.size == 3)
val outputOrdering = planned.outputOrdering
assert(outputOrdering.size == 1)
// The sort order should have 3 children, not 4, because `t1.id * 2` and `2 * t1.id` are
// the same expression and must only be listed once.
assert(outputOrdering.head.children.size == 3)
// Of those three children, two are plain attribute references and one is a Multiply.
assert(outputOrdering.head.children.count(_.isInstanceOf[AttributeReference]) == 2)
assert(outputOrdering.head.children.count(_.isInstanceOf[Multiply]) == 1)
}
}
}

test("aliases to expressions should not be replaced") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("df1", "df2") {
Expand Down