@@ -151,6 +151,7 @@ package object dsl extends SQLConfHelper {
 
     def asc: SortOrder = SortOrder(expr, Ascending)
     def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Seq.empty)
+    def const: SortOrder = SortOrder(expr, Constant)
     def desc: SortOrder = SortOrder(expr, Descending)
     def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Seq.empty)
     def as(alias: String): NamedExpression = Alias(expr, alias)()
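Note on the new DSL helper: a minimal sketch of how `const` might be exercised from a test, assuming the usual catalyst DSL imports; the attribute name and assertions are illustrative, not from the diff.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Ascending, Constant, NullsFirst, SortOrder}

// `const` mirrors the existing `asc`/`desc` helpers but tags the expression
// with the new Constant direction.
val ascOrder: SortOrder = Symbol("a").attr.asc
val constOrder: SortOrder = Symbol("a").attr.const

assert(ascOrder.direction == Ascending)
assert(constOrder.direction == Constant)
// Constant's default null ordering is NullsFirst (see SortDirection below).
assert(constOrder.nullOrdering == NullsFirst)
```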
@@ -45,6 +45,11 @@ case object Descending extends SortDirection {
   override def defaultNullOrdering: NullOrdering = NullsLast
 }
 
+case object Constant extends SortDirection {
+  override def sql: String = "CONST"
+  override def defaultNullOrdering: NullOrdering = NullsFirst
+}
+
 case object NullsFirst extends NullOrdering {
   override def sql: String = "NULLS FIRST"
 }
@@ -69,8 +74,13 @@ case class SortOrder(
 
   override def children: Seq[Expression] = child +: sameOrderExpressions
 
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForOrderingExpr(dataType, prettyName)
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (direction == Constant) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeUtils.checkForOrderingExpr(dataType, prettyName)
+    }
+  }
 
   override def dataType: DataType = child.dataType
   override def nullable: Boolean = child.nullable
@@ -81,8 +91,8 @@ case class SortOrder(
   def isAscending: Boolean = direction == Ascending
 
   def satisfies(required: SortOrder): Boolean = {
-    children.exists(required.child.semanticEquals) &&
-      direction == required.direction && nullOrdering == required.nullOrdering
+    children.exists(required.child.semanticEquals) && (direction == Constant ||
+      direction == required.direction && nullOrdering == required.nullOrdering)
   }
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): SortOrder =
@@ -101,21 +111,38 @@ object SortOrder {
    * Returns if a sequence of SortOrder satisfies another sequence of SortOrder.
    *
    * SortOrder sequence A satisfies SortOrder sequence B if and only if B is an equivalent of A
-   * or of A's prefix. Here are examples of ordering A satisfying ordering B:
+   * or of A's prefix, except for SortOrder in B that satisfies any constant SortOrder in A.
+   *
+   * Here are examples of ordering A satisfying ordering B:
    * <ul>
    *   <li>ordering A is [x, y] and ordering B is [x]</li>
+   *   <li>ordering A is [z(const), x, y] and ordering B is [x, z]</li>
    *   <li>ordering A is [x(sameOrderExpressions=x1)] and ordering B is [x1]</li>
    *   <li>ordering A is [x(sameOrderExpressions=x1), y] and ordering B is [x1]</li>
    * </ul>
    */
-  def orderingSatisfies(ordering1: Seq[SortOrder], ordering2: Seq[SortOrder]): Boolean = {
-    if (ordering2.isEmpty) {
-      true
-    } else if (ordering2.length > ordering1.length) {
+  def orderingSatisfies(
+      providedOrdering: Seq[SortOrder], requiredOrdering: Seq[SortOrder]): Boolean = {
+    if (requiredOrdering.isEmpty) {
+      return true
+    }
+
+    val (constantProvidedOrdering, nonConstantProvidedOrdering) = providedOrdering.partition {
+      case SortOrder(_, Constant, _, _) => true
+      case SortOrder(child, _, _, _) => child.foldable
+    }
+
+    val effectiveRequiredOrdering = requiredOrdering.filterNot { requiredOrder =>
+      constantProvidedOrdering.exists { providedOrder =>
+        providedOrder.satisfies(requiredOrder)
+      }
+    }
+
+    if (effectiveRequiredOrdering.length > nonConstantProvidedOrdering.length) {
       false
     } else {
-      ordering2.zip(ordering1).forall {
-        case (o2, o1) => o1.satisfies(o2)
+      effectiveRequiredOrdering.zip(nonConstantProvidedOrdering).forall {
+        case (required, provided) => provided.satisfies(required)
       }
     }
   }
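To make the new contract concrete, here is a sketch of the `[z(const), x, y]` vs `[x, z]` example from the updated scaladoc, assuming the catalyst DSL; attribute names are illustrative.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.SortOrder

val x = Symbol("x").attr
val y = Symbol("y").attr
val z = Symbol("z").attr

// Provided ordering [z(const), x, y]; required ordering [x, z].
val provided = Seq(z.const, x.asc, y.asc)
val required = Seq(x.asc, z.asc)

// z.const satisfies z.asc regardless of direction and null ordering, so the
// effective required ordering reduces to [x], which the non-constant prefix
// [x, y] of the provided ordering satisfies.
assert(SortOrder.orderingSatisfies(provided, required))

// A required ordering whose head does not match the non-constant prefix fails.
assert(!SortOrder.orderingSatisfies(provided, Seq(y.asc)))
```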
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Constant, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
 import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
@@ -199,6 +199,8 @@ case class Mode(
         this.copy(child = child, reverseOpt = Some(true))
       case SortOrder(child, Descending, _, _) =>
         this.copy(child = child, reverseOpt = Some(false))
+      case SortOrder(child, Constant, _, _) =>
+        this.copy(child = child)
     }
     case _ => this
   }
@@ -382,7 +382,7 @@ case class PercentileCont(left: Expression, right: Expression, reverse: Boolean
         nodeName, 1, orderingWithinGroup.length)
     }
     orderingWithinGroup.head match {
-      case SortOrder(child, Ascending, _, _) => this.copy(left = child)
+      case SortOrder(child, Ascending | Constant, _, _) => this.copy(left = child)
       case SortOrder(child, Descending, _, _) => this.copy(left = child, reverse = true)
     }
   }
@@ -440,7 +440,7 @@ case class PercentileDisc(
         nodeName, 1, orderingWithinGroup.length)
     }
     orderingWithinGroup.head match {
-      case SortOrder(expr, Ascending, _, _) => this.copy(child = expr)
+      case SortOrder(expr, Ascending | Constant, _, _) => this.copy(child = expr)
       case SortOrder(expr, Descending, _, _) => this.copy(child = expr, reverse = true)
     }
   }
@@ -1919,7 +1919,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
 object EliminateSorts extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(SORT)) {
     case s @ Sort(orders, _, child, _) if orders.isEmpty || orders.exists(_.child.foldable) =>
-      val newOrders = orders.filterNot(_.child.foldable)
+      val newOrders = orders.filterNot(o => o.direction != Constant && o.child.foldable)
       if (newOrders.isEmpty) {
         child
       } else {
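The rule previously dropped every foldable sort key; it now retains keys whose direction is Constant even though their children are foldable, since those orders now carry information downstream. A sketch of the filter's behavior (names illustrative):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Constant, Literal, SortOrder}

val orders = Seq(
  Literal(1).asc,                  // foldable child, ordinary direction
  Symbol("a").attr.asc,            // regular column order
  SortOrder(Literal(2), Constant)) // foldable child, Constant direction

val newOrders = orders.filterNot(o => o.direction != Constant && o.child.foldable)
// Literal(1).asc is eliminated as before; the Constant order now survives.
assert(newOrders == Seq(Symbol("a").attr.asc, SortOrder(Literal(2), Constant)))
```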
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Empty2Null, Expression, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Constant, Empty2Null, Expression, NamedExpression, SortOrder}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -128,6 +128,8 @@ trait AliasAwareQueryOutputOrdering[T <: QueryPlan[T]]
         }
       }
     }
-    newOrdering.takeWhile(_.isDefined).flatten.toSeq
+    newOrdering.takeWhile(_.isDefined).flatten.toSeq ++ outputExpressions.collect {
+      case a @ Alias(child, _) if child.foldable => SortOrder(a.toAttribute, Constant)
+    }
   }
 }
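The effect is that a projection aliasing a foldable expression now advertises a Constant order on that alias in addition to any ordering it propagates. A sketch of the appended collect, with a hypothetical projection list like `SELECT a, 1 AS c`:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Alias, Constant, Literal, NamedExpression, SortOrder}

val outputExpressions: Seq[NamedExpression] =
  Seq(Symbol("a").attr, Literal(1).as("c"))

// Every foldable alias contributes a Constant sort order on its attribute.
val constOrders = outputExpressions.collect {
  case a @ Alias(child, _) if child.foldable => SortOrder(a.toAttribute, Constant)
}
assert(constOrders.map(_.direction) == Seq(Constant))
```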
@@ -912,7 +912,8 @@ case class Sort(
   override def maxRowsPerPartition: Option[Long] = {
     if (global) maxRows else child.maxRowsPerPartition
   }
-  override def outputOrdering: Seq[SortOrder] = order
+  override def outputOrdering: Seq[SortOrder] =
+    order ++ child.outputOrdering.filter(_.direction == Constant)
   final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
   override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
 }
@@ -43,9 +43,10 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
       val sortOrder = direction match {
         case Ascending => BoundReference(0, dataType, nullable = true).asc
         case Descending => BoundReference(0, dataType, nullable = true).desc
+        case Constant => BoundReference(0, dataType, nullable = true).const
       }
       val expectedCompareResult = direction match {
-        case Ascending => signum(expected)
+        case Ascending | Constant => signum(expected)
         case Descending => -1 * signum(expected)
       }
 
@@ -46,7 +46,8 @@ case class SortExec(
 
   override def output: Seq[Attribute] = child.output
 
-  override def outputOrdering: Seq[SortOrder] = sortOrder
+  override def outputOrdering: Seq[SortOrder] =
+    sortOrder ++ child.outputOrdering.filter(_.direction == Constant)
 
   // sort performed is local within a given partition so will retain
   // child operator's partitioning
@@ -73,15 +74,17 @@
    * should make it public.
    */
   def createSorter(): UnsafeExternalRowSorter = {
+    val effectiveSortOrder = sortOrder.filterNot(_.direction == Constant)
+
     rowSorter = new ThreadLocal[UnsafeExternalRowSorter]()
 
     val ordering = RowOrdering.create(sortOrder, output)
 
     // The comparator for comparing prefix
-    val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+    val boundSortExpression = BindReferences.bindReference(effectiveSortOrder.head, output)
     val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
 
-    val canUseRadixSort = enableRadixSort && sortOrder.length == 1 &&
+    val canUseRadixSort = enableRadixSort && effectiveSortOrder.length == 1 &&
       SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)
 
     // The generator for prefix
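Note the asymmetry inside createSorter: the row comparator (`RowOrdering.create(sortOrder, output)`) still sees the full sort order, while the prefix comparator and the radix-sort check use only the non-constant entries, since a constant key contributes nothing to comparisons. A sketch of the filtering step (names illustrative; it presumes at least one non-constant order remains, otherwise `effectiveSortOrder.head` would throw):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Constant, Literal, SortOrder}

val sortOrder = Seq(SortOrder(Literal(1), Constant), Symbol("a").attr.asc)
val effectiveSortOrder = sortOrder.filterNot(_.direction == Constant)

// Prefix comparison and radix-sort eligibility are derived from `a ASC` only.
assert(effectiveSortOrder == Seq(Symbol("a").attr.asc))
```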
@@ -63,6 +63,8 @@ object SortPrefixUtils {
         PrefixComparators.STRING_DESC_NULLS_FIRST
       case Descending =>
         PrefixComparators.STRING_DESC
+      case Constant =>
+        NoOpPrefixComparator
     }
   }
 
@@ -76,6 +78,8 @@
         PrefixComparators.BINARY_DESC_NULLS_FIRST
       case Descending =>
         PrefixComparators.BINARY_DESC
+      case Constant =>
+        NoOpPrefixComparator
     }
   }
 
@@ -89,6 +93,8 @@
         PrefixComparators.LONG_DESC_NULLS_FIRST
       case Descending =>
         PrefixComparators.LONG_DESC
+      case Constant =>
+        NoOpPrefixComparator
     }
   }
 
@@ -102,6 +108,8 @@
         PrefixComparators.DOUBLE_DESC_NULLS_FIRST
       case Descending =>
         PrefixComparators.DOUBLE_DESC
+      case Constant =>
+        NoOpPrefixComparator
     }
   }
 
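All four branches fall back to a comparator that does no work: under a Constant direction every value in the column is by definition equal, so prefix comparison can safely report a tie. `NoOpPrefixComparator` is presumably defined along these lines (a reconstruction, not part of the diff):

```scala
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator

// Assumed shape: report every pair of prefixes as equal, which is sound when
// the sort direction guarantees a single value per column.
object NoOpPrefixComparator extends PrefixComparator {
  override def compare(prefix1: Long, prefix2: Long): Int = 0
}
```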
@@ -439,9 +439,7 @@ case class InMemoryRelation(
   override def innerChildren: Seq[SparkPlan] = Seq(cachedPlan)
 
   override def doCanonicalize(): logical.LogicalPlan =
-    copy(output = output.map(QueryPlan.normalizeExpressions(_, output)),
-      cacheBuilder,
-      outputOrdering)
+    withOutput(output.map(QueryPlan.normalizeExpressions(_, output)))
 
   @transient val partitionStatistics = new PartitionStatistics(output)
 
@@ -469,8 +467,13 @@
     }
   }
 
-  def withOutput(newOutput: Seq[Attribute]): InMemoryRelation =
-    InMemoryRelation(newOutput, cacheBuilder, outputOrdering, statsOfPlanToCache)
+  def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
+    val map = AttributeMap(output.zip(newOutput))
+    val newOutputOrdering = outputOrdering
+      .map(_.transform { case a: Attribute => map(a) })
+      .asInstanceOf[Seq[SortOrder]]
+    InMemoryRelation(newOutput, cacheBuilder, newOutputOrdering, statsOfPlanToCache)
+  }
 
   override def newInstance(): this.type = {
     InMemoryRelation(
@@ -487,6 +490,12 @@ case class InMemoryRelation(
     cloned
   }
 
+  override def makeCopy(newArgs: Array[AnyRef]): LogicalPlan = {
+    val copied = super.makeCopy(newArgs).asInstanceOf[InMemoryRelation]
+    copied.statsOfPlanToCache = this.statsOfPlanToCache
+    copied
+  }
+
   override def simpleString(maxFields: Int): String =
     s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}"
 

Review comment from the author on `copied.statsOfPlanToCache = this.statsOfPlanToCache`:

> I feel this is a hidden bug just exposed by this change.
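The rewritten withOutput remaps the cached ordering onto the fresh output attributes; previously the returned relation kept sort orders that referenced attributes it no longer exposed. A minimal sketch of the remapping with hypothetical attributes:

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeMap, AttributeReference, SortOrder}
import org.apache.spark.sql.types.IntegerType

// Old and new output attributes differ only by expression id.
val aOld = AttributeReference("a", IntegerType)()
val aNew = AttributeReference("a", IntegerType)()

val outputOrdering = Seq(SortOrder(aOld, Ascending))
val map = AttributeMap(Seq(aOld -> aNew))
val newOutputOrdering = outputOrdering
  .map(_.transform { case attr: Attribute => map(attr) })
  .asInstanceOf[Seq[SortOrder]]

// The ordering now references aNew, consistent with the relation's new output.
assert(newOutputOrdering.head.child.asInstanceOf[Attribute].exprId == aNew.exprId)
```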
@@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.PREDICATES
 import org.apache.spark.rdd.RDD
@@ -827,6 +828,8 @@ object DataSourceStrategy
     val directionV2 = directionV1 match {
       case Ascending => SortDirection.ASCENDING
       case Descending => SortDirection.DESCENDING
+      case Constant =>
+        throw SparkException.internalError(s"Unexpected catalyst sort direction $Constant")
     }
     val nullOrderingV2 = nullOrderingV1 match {
       case NullsFirst => NullOrdering.NULLS_FIRST
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Empty2Null, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Constant, Empty2Null, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -199,13 +199,27 @@ object V1WritesUtils {
     expressions.exists(_.exists(_.isInstanceOf[Empty2Null]))
   }
 
+  // SortOrder sequence A (outputOrdering) satisfies SortOrder sequence B (requiredOrdering)
+  // if and only if B is an equivalent of A or of A's prefix, except for SortOrder in B that
+  // satisfies any constant SortOrder in A.
   def isOrderingMatched(
       requiredOrdering: Seq[Expression],
       outputOrdering: Seq[SortOrder]): Boolean = {
-    if (requiredOrdering.length > outputOrdering.length) {
+    val (constantOutputOrdering, nonConstantOutputOrdering) = outputOrdering.partition {
+      case SortOrder(_, Constant, _, _) => true
+      case SortOrder(child, _, _, _) => child.foldable
+    }
+
+    val effectiveRequiredOrdering = requiredOrdering.filterNot { requiredOrder =>
+      constantOutputOrdering.exists { outputOrder =>
+        outputOrder.satisfies(outputOrder.copy(child = requiredOrder))
+      }
+    }
+
+    if (effectiveRequiredOrdering.length > nonConstantOutputOrdering.length) {
       false
     } else {
-      requiredOrdering.zip(outputOrdering).forall {
+      effectiveRequiredOrdering.zip(nonConstantOutputOrdering).forall {
         case (requiredOrder, outputOrder) =>
           outputOrder.satisfies(outputOrder.copy(child = requiredOrder))
       }
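Concretely for V1 writes: if the required ordering asks for a column the child already emits as a constant, only the remaining keys need a real sort, and the writer can avoid inserting an extra Sort node. A sketch (column names illustrative):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Constant, Expression, SortOrder}
import org.apache.spark.sql.execution.datasources.V1WritesUtils

val p = Symbol("p").attr // e.g. a partition column pinned to one value
val b = Symbol("b").attr // e.g. a bucket or sort column

val requiredOrdering: Seq[Expression] = Seq(p, b)
val outputOrdering = Seq(SortOrder(p, Constant), b.asc)

// p is absorbed by the constant order and [b] matches the non-constant
// prefix, so isOrderingMatched holds and the sort is skipped.
assert(V1WritesUtils.isOrderingMatched(requiredOrdering, outputOrdering))
```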
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, ExtractANSIIntervalDays, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecifiedWindowFrame, TimestampAddInterval, TimestampAddYMInterval, UnaryMinus, UnboundedFollowing, UnboundedPreceding, UnsafeProjection, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, Constant, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, ExtractANSIIntervalDays, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecifiedWindowFrame, TimestampAddInterval, TimestampAddYMInterval, UnaryMinus, UnboundedFollowing, UnboundedPreceding, UnsafeProjection, WindowExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
@@ -95,7 +95,7 @@ trait WindowEvaluatorFactoryBase {
       // Flip the sign of the offset when processing the order is descending
       val boundOffset = sortExpr.direction match {
         case Descending => UnaryMinus(offset)
-        case Ascending => offset
+        case Ascending | Constant => offset
       }
 
       // Create the projection which returns the current 'value' modified by adding the offset.