Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -966,9 +966,9 @@ class Analyzer(
case s @ Sort(orders, global, child)
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
val newOrders = orders map {
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) =>
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
if (index > 0 && index <= child.output.size) {
SortOrder(child.output(index - 1), direction, nullOrdering)
SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty)
} else {
s.failAnalysis(
s"ORDER BY position $index is not in select list " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan]
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
val newOrders = s.order.map {
case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
withOrigin(order.origin)(order.copy(child = newOrdinal))
case other => other
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ package object dsl {
def cast(to: DataType): Expression = Cast(expr, to)

def asc: SortOrder = SortOrder(expr, Ascending)
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast)
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
def desc: SortOrder = SortOrder(expr, Descending)
def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst)
def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty)
def as(alias: String): NamedExpression = Alias(expr, alias)()
def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,15 @@ case object NullsLast extends NullOrdering{
/**
* An expression that can be used to sort a tuple. This class extends expression primarily so that
* transformations over expression will descend into its child.
* `sameOrderExpressions` is a set of expressions with the same sort order as the child. It is
* derived from equivalence relation in an operator, e.g. left/right keys of an inner sort merge
* join.
*/
case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering)
case class SortOrder(
child: Expression,
direction: SortDirection,
nullOrdering: NullOrdering,
sameOrderExpressions: Set[Expression])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we don't need this, and we can rely on EqualTo constraint to infer this information. Unfortunately the constraint only exists in logical plan, so we can't find a better solution for this case. cc @gatorsmile do you have a better idea?

extends UnaryExpression with Unevaluable {

/** Sort order is not foldable because we don't have an eval for it. */
Expand All @@ -75,11 +82,19 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering:
override def sql: String = child.sql + " " + direction.sql + " " + nullOrdering.sql

def isAscending: Boolean = direction == Ascending

def satisfies(required: SortOrder): Boolean = {
(sameOrderExpressions + child).exists(required.child.semanticEquals) &&
direction == required.direction && nullOrdering == required.nullOrdering
}
}

object SortOrder {
def apply(child: Expression, direction: SortDirection): SortOrder = {
new SortOrder(child, direction, direction.defaultNullOrdering)
def apply(
child: Expression,
direction: SortDirection,
sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = {
new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
} else {
direction.defaultNullOrdering
}
SortOrder(expression(ctx.expression), direction, nullOrdering)
SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty)
}

/**
Expand Down
8 changes: 4 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ class Column(val expr: Expression) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) }
def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) }

/**
* Returns a descending ordering used in sorting, where null values appear after non-null values.
Expand All @@ -1052,7 +1052,7 @@ class Column(val expr: Expression) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) }
def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) }

/**
* Returns an ascending ordering used in sorting.
Expand Down Expand Up @@ -1082,7 +1082,7 @@ class Column(val expr: Expression) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) }
def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) }

/**
* Returns an ordering used in sorting, where null values appear after non-null values.
Expand All @@ -1097,7 +1097,7 @@ class Column(val expr: Expression) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) }
def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) }

/**
* Prints the expression to the console for debugging purpose.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
} else {
requiredOrdering.zip(child.outputOrdering).forall {
case (requiredOrder, childOutputOrder) =>
requiredOrder.semanticEquals(childOutputOrder)
childOutputOrder.satisfies(requiredOrder)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,37 @@ case class SortMergeJoinExec(
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

override def outputOrdering: Seq[SortOrder] = joinType match {
// For inner join, orders of both sides keys should be kept.
case Inner =>
val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering)
val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering)
leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) =>
// Also add the right key and its `sameOrderExpressions`
SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey
.sameOrderExpressions)
}
// For left and right outer joins, the output is ordered by the streamed input's join keys.
case LeftOuter => requiredOrders(leftKeys)
case RightOuter => requiredOrders(rightKeys)
case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering)
case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering)
// There are null rows in both streams, so there is no order.
case FullOuter => Nil
case _: InnerLike | LeftExistence(_) => requiredOrders(leftKeys)
case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering)
case x =>
throw new IllegalArgumentException(
s"${getClass.getSimpleName} should not take $x as the JoinType")
}

/**
* For SMJ, child's output must have been sorted on key or expressions with the same order as
* key, so we can get ordering for key from child's output ordering.
*/
private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder])
: Seq[SortOrder] = {
keys.zip(childOutputOrdering).map { case (key, childOrder) =>
SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key)
}
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] =
requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,14 +477,18 @@ class PlannerSuite extends SharedSQLContext {

private val exprA = Literal(1)
private val exprB = Literal(2)
private val exprC = Literal(3)
private val orderingA = SortOrder(exprA, Ascending)
private val orderingB = SortOrder(exprB, Ascending)
private val orderingC = SortOrder(exprC, Ascending)
private val planA = DummySparkPlan(outputOrdering = Seq(orderingA),
outputPartitioning = HashPartitioning(exprA :: Nil, 5))
private val planB = DummySparkPlan(outputOrdering = Seq(orderingB),
outputPartitioning = HashPartitioning(exprB :: Nil, 5))
private val planC = DummySparkPlan(outputOrdering = Seq(orderingC),
outputPartitioning = HashPartitioning(exprC :: Nil, 5))

assert(orderingA != orderingB)
assert(orderingA != orderingB && orderingA != orderingC && orderingB != orderingC)

private def assertSortRequirementsAreSatisfied(
childPlan: SparkPlan,
Expand All @@ -508,6 +512,30 @@ class PlannerSuite extends SharedSQLContext {
}
}

test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") {
val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB)
// Both left and right keys should be sorted after the SMJ.
Seq(orderingA, orderingB).foreach { ordering =>
assertSortRequirementsAreSatisfied(
childPlan = innerSmj,
requiredOrdering = Seq(ordering),
shouldHaveSort = false)
}
}

test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " +
"child SMJ") {
val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB)
val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC)
// After the second SMJ, exprA, exprB and exprC should all be sorted.
Seq(orderingA, orderingB, orderingC).foreach { ordering =>
assertSortRequirementsAreSatisfied(
childPlan = parentSmj,
requiredOrdering = Seq(ordering),
shouldHaveSort = false)
}
}

test("EnsureRequirements for sort operator after left outer sort merge join") {
// Only left key is sorted after left outer SMJ (thus doesn't need a sort).
val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB)
Expand Down