diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 1d7cf5455e57b..e9e897cfb78ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -151,6 +151,7 @@ package object dsl extends SQLConfHelper {
def asc: SortOrder = SortOrder(expr, Ascending)
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Seq.empty)
+ def const: SortOrder = SortOrder(expr, Constant)
def desc: SortOrder = SortOrder(expr, Descending)
def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Seq.empty)
def as(alias: String): NamedExpression = Alias(expr, alias)()
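
Illustrative sketch of the new DSL helper in use (not part of the patch; assumes the Catalyst test classpath):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Constant, NullsFirst}
    import org.apache.spark.sql.types.StringType

    val k = AttributeReference("k", StringType)()
    val order = k.const // builds SortOrder(k, Constant), mirroring k.asc / k.desc
    assert(order.direction == Constant)
    // The two-arg SortOrder constructor fills in Constant.defaultNullOrdering.
    assert(order.nullOrdering == NullsFirst)
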
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 824024a84cbad..73849f831237c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -45,6 +45,11 @@ case object Descending extends SortDirection {
override def defaultNullOrdering: NullOrdering = NullsLast
}
+case object Constant extends SortDirection {
+ override def sql: String = "CONST"
+ override def defaultNullOrdering: NullOrdering = NullsFirst
+}
+
case object NullsFirst extends NullOrdering {
override def sql: String = "NULLS FIRST"
}
@@ -69,8 +74,13 @@ case class SortOrder(
override def children: Seq[Expression] = child +: sameOrderExpressions
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForOrderingExpr(dataType, prettyName)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (direction == Constant) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeUtils.checkForOrderingExpr(dataType, prettyName)
+ }
+ }
override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
@@ -81,8 +91,8 @@ case class SortOrder(
def isAscending: Boolean = direction == Ascending
def satisfies(required: SortOrder): Boolean = {
- children.exists(required.child.semanticEquals) &&
- direction == required.direction && nullOrdering == required.nullOrdering
+ children.exists(required.child.semanticEquals) && (direction == Constant ||
+ direction == required.direction && nullOrdering == required.nullOrdering)
}
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): SortOrder =
@@ -101,21 +111,38 @@ object SortOrder {
* Returns if a sequence of SortOrder satisfies another sequence of SortOrder.
*
* SortOrder sequence A satisfies SortOrder sequence B if and only if B is an equivalent of A
- * or of A's prefix. Here are examples of ordering A satisfying ordering B:
+ * or of A's prefix, after dropping any SortOrder in B that is satisfied by a constant SortOrder in A.
+ *
+ * Here are examples of ordering A satisfying ordering B:
*
* - ordering A is [x, y] and ordering B is [x]
+ * - ordering A is [z(const), x, y] and ordering B is [x, z]
* - ordering A is [x(sameOrderExpressions=x1)] and ordering B is [x1]
* - ordering A is [x(sameOrderExpressions=x1), y] and ordering B is [x1]
*
*/
- def orderingSatisfies(ordering1: Seq[SortOrder], ordering2: Seq[SortOrder]): Boolean = {
- if (ordering2.isEmpty) {
- true
- } else if (ordering2.length > ordering1.length) {
+ def orderingSatisfies(
+ providedOrdering: Seq[SortOrder], requiredOrdering: Seq[SortOrder]): Boolean = {
+ if (requiredOrdering.isEmpty) {
+ return true
+ }
+
+ val (constantProvidedOrdering, nonConstantProvidedOrdering) = providedOrdering.partition {
+ case SortOrder(_, Constant, _, _) => true
+ case SortOrder(child, _, _, _) => child.foldable
+ }
+
+ val effectiveRequiredOrdering = requiredOrdering.filterNot { requiredOrder =>
+ constantProvidedOrdering.exists { providedOrder =>
+ providedOrder.satisfies(requiredOrder)
+ }
+ }
+
+ if (effectiveRequiredOrdering.length > nonConstantProvidedOrdering.length) {
false
} else {
- ordering2.zip(ordering1).forall {
- case (o2, o1) => o1.satisfies(o2)
+ effectiveRequiredOrdering.zip(nonConstantProvidedOrdering).forall {
+ case (required, provided) => provided.satisfies(required)
}
}
}
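
A Constant direction marks a sort key whose value is identical across all rows, so it is trivially "sorted" in any direction and satisfies any required ordering on the same expression. A sketch of the new semantics (illustrative only; assumes the Catalyst DSL on the classpath):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SortOrder}
    import org.apache.spark.sql.types.{IntegerType, StringType}

    val x = AttributeReference("x", IntegerType)()
    val y = AttributeReference("y", IntegerType)()
    val z = AttributeReference("z", StringType)()

    // Provided: z is constant across rows; data is sorted by x, then y.
    val provided = Seq(z.const, x.asc, y.asc)
    // Required: [x, z]. The constant z entry absorbs the required z order
    // regardless of its direction; x then matches the first non-constant
    // provided order, so the requirement is satisfied.
    val required = Seq(x.asc, z.asc)
    assert(SortOrder.orderingSatisfies(provided, required))
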
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index 7a4f04bf04f7a..b61dae5850618 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Constant, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
@@ -199,6 +199,8 @@ case class Mode(
this.copy(child = child, reverseOpt = Some(true))
case SortOrder(child, Descending, _, _) =>
this.copy(child = child, reverseOpt = Some(false))
+ case SortOrder(child, Constant, _, _) =>
+ this.copy(child = child)
}
case _ => this
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala
index 6dfa1b499df23..942c06f60d123 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala
@@ -382,7 +382,7 @@ case class PercentileCont(left: Expression, right: Expression, reverse: Boolean
nodeName, 1, orderingWithinGroup.length)
}
orderingWithinGroup.head match {
- case SortOrder(child, Ascending, _, _) => this.copy(left = child)
+ case SortOrder(child, Ascending | Constant, _, _) => this.copy(left = child)
case SortOrder(child, Descending, _, _) => this.copy(left = child, reverse = true)
}
}
@@ -440,7 +440,7 @@ case class PercentileDisc(
nodeName, 1, orderingWithinGroup.length)
}
orderingWithinGroup.head match {
- case SortOrder(expr, Ascending, _, _) => this.copy(child = expr)
+ case SortOrder(expr, Ascending | Constant, _, _) => this.copy(child = expr)
case SortOrder(expr, Descending, _, _) => this.copy(child = expr, reverse = true)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fc65c24afcb8f..b5941cd5eddcf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1919,7 +1919,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
object EliminateSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(SORT)) {
case s @ Sort(orders, _, child, _) if orders.isEmpty || orders.exists(_.child.foldable) =>
- val newOrders = orders.filterNot(_.child.foldable)
+ val newOrders = orders.filterNot(o => o.direction != Constant && o.child.foldable)
if (newOrders.isEmpty) {
child
} else {
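
Illustrative sketch of the revised filter (not part of the patch): foldable keys with a real direction are still pruned, but a Constant-direction entry now survives, so the constant-ordering information is not lost.

    import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Constant, Literal, SortOrder}
    import org.apache.spark.sql.types.IntegerType

    val i = AttributeReference("i", IntegerType)()
    val orders = Seq(
      SortOrder(Literal("0"), Ascending), // foldable child, Ascending: dropped
      SortOrder(i, Constant),             // Constant direction: kept
      SortOrder(i, Ascending))            // ordinary key: kept
    val newOrders = orders.filterNot(o => o.direction != Constant && o.child.foldable)
    assert(newOrders.map(_.direction) == Seq(Constant, Ascending))
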
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala
index e1a9e8b5ea810..efbd7b0c8b810 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
import scala.collection.mutable
import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Empty2Null, Expression, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Constant, Empty2Null, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.internal.SQLConf
/**
@@ -128,6 +128,8 @@ trait AliasAwareQueryOutputOrdering[T <: QueryPlan[T]]
}
}
}
- newOrdering.takeWhile(_.isDefined).flatten.toSeq
+ newOrdering.takeWhile(_.isDefined).flatten.toSeq ++ outputExpressions.collect {
+ case a @ Alias(child, _) if child.foldable => SortOrder(a.toAttribute, Constant)
+ }
}
}
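
With this change, every alias over a foldable expression advertises a Constant ordering on its output attribute, even when no Sort is present. A sketch (illustrative; assumes logical Project mixes in AliasAwareQueryOutputOrdering, as on master):

    import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Constant, Literal}
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
    import org.apache.spark.sql.types.IntegerType

    val i = AttributeReference("i", IntegerType)()
    // SELECT i, '0' AS k FROM ...
    val proj = Project(Seq(i, Alias(Literal("0"), "k")()), LocalRelation(i))
    // The foldable alias contributes SortOrder(k, Constant) to the ordering.
    assert(proj.outputOrdering.exists(_.direction == Constant))
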
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ad6939422b976..5a7548f13f1f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -912,7 +912,8 @@ case class Sort(
override def maxRowsPerPartition: Option[Long] = {
if (global) maxRows else child.maxRowsPerPartition
}
- override def outputOrdering: Seq[SortOrder] = order
+ override def outputOrdering: Seq[SortOrder] =
+ order ++ child.outputOrdering.filter(_.direction == Constant)
final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
}
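
Sort previously reported only its own keys; it now also forwards the child's Constant orders. Continuing the Project sketch above (illustrative; assumes any extra Sort constructor arguments take their defaults):

    import org.apache.spark.sql.catalyst.expressions.{Ascending, Constant, SortOrder}
    import org.apache.spark.sql.catalyst.plans.logical.Sort

    // Sorting by i keeps k's constant ordering visible above the Sort, so a
    // downstream requirement such as [k, i] can still be matched.
    val sorted = Sort(Seq(SortOrder(i, Ascending)), global = false, proj)
    assert(sorted.outputOrdering.map(_.direction) == Seq(Ascending, Constant))
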
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
index 06c8b5ccef652..5facdaeb1aca1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
@@ -43,9 +43,10 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
val sortOrder = direction match {
case Ascending => BoundReference(0, dataType, nullable = true).asc
case Descending => BoundReference(0, dataType, nullable = true).desc
+ case Constant => BoundReference(0, dataType, nullable = true).const
}
val expectedCompareResult = direction match {
- case Ascending => signum(expected)
+ case Ascending | Constant => signum(expected)
case Descending => -1 * signum(expected)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 11fde41aae9e4..708e054664c57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -46,7 +46,8 @@ case class SortExec(
override def output: Seq[Attribute] = child.output
- override def outputOrdering: Seq[SortOrder] = sortOrder
+ override def outputOrdering: Seq[SortOrder] =
+ sortOrder ++ child.outputOrdering.filter(_.direction == Constant)
// sort performed is local within a given partition so will retain
// child operator's partitioning
@@ -73,15 +74,17 @@ case class SortExec(
* should make it public.
*/
def createSorter(): UnsafeExternalRowSorter = {
+ val effectiveSortOrder = sortOrder.filterNot(_.direction == Constant)
+
rowSorter = new ThreadLocal[UnsafeExternalRowSorter]()
val ordering = RowOrdering.create(sortOrder, output)
// The comparator for comparing prefix
- val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+ val boundSortExpression = BindReferences.bindReference(effectiveSortOrder.head, output)
val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
- val canUseRadixSort = enableRadixSort && sortOrder.length == 1 &&
+ val canUseRadixSort = enableRadixSort && effectiveSortOrder.length == 1 &&
SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)
// The generator for prefix
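
A Constant key never discriminates rows, so it must be ignored when choosing the prefix comparator and when checking radix-sort eligibility. A sketch of the filtering above (illustrative only):

    import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Constant, SortOrder}
    import org.apache.spark.sql.types.{IntegerType, StringType}

    val k = AttributeReference("k", StringType)()
    val i = AttributeReference("i", IntegerType)()
    val sortOrder = Seq(SortOrder(k, Constant), SortOrder(i, Ascending))
    // The prefix comparator must come from the first order that actually
    // discriminates rows: i.asc, not the constant k entry.
    val effectiveSortOrder = sortOrder.filterNot(_.direction == Constant)
    assert(effectiveSortOrder.head == SortOrder(i, Ascending))
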
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 4b561b813067e..36dd7a76c0d21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -63,6 +63,8 @@ object SortPrefixUtils {
PrefixComparators.STRING_DESC_NULLS_FIRST
case Descending =>
PrefixComparators.STRING_DESC
+ case Constant =>
+ NoOpPrefixComparator
}
}
@@ -76,6 +78,8 @@ object SortPrefixUtils {
PrefixComparators.BINARY_DESC_NULLS_FIRST
case Descending =>
PrefixComparators.BINARY_DESC
+ case Constant =>
+ NoOpPrefixComparator
}
}
@@ -89,6 +93,8 @@ object SortPrefixUtils {
PrefixComparators.LONG_DESC_NULLS_FIRST
case Descending =>
PrefixComparators.LONG_DESC
+ case Constant =>
+ NoOpPrefixComparator
}
}
@@ -102,6 +108,8 @@ object SortPrefixUtils {
PrefixComparators.DOUBLE_DESC_NULLS_FIRST
case Descending =>
PrefixComparators.DOUBLE_DESC
+ case Constant =>
+ NoOpPrefixComparator
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index eabbc7fc74f50..bf7491625fa03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -439,9 +439,7 @@ case class InMemoryRelation(
override def innerChildren: Seq[SparkPlan] = Seq(cachedPlan)
override def doCanonicalize(): logical.LogicalPlan =
- copy(output = output.map(QueryPlan.normalizeExpressions(_, output)),
- cacheBuilder,
- outputOrdering)
+ withOutput(output.map(QueryPlan.normalizeExpressions(_, output)))
@transient val partitionStatistics = new PartitionStatistics(output)
@@ -469,8 +467,13 @@ case class InMemoryRelation(
}
}
- def withOutput(newOutput: Seq[Attribute]): InMemoryRelation =
- InMemoryRelation(newOutput, cacheBuilder, outputOrdering, statsOfPlanToCache)
+ def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
+ val map = AttributeMap(output.zip(newOutput))
+ val newOutputOrdering = outputOrdering
+ .map(_.transform { case a: Attribute => map(a) })
+ .asInstanceOf[Seq[SortOrder]]
+ InMemoryRelation(newOutput, cacheBuilder, newOutputOrdering, statsOfPlanToCache)
+ }
override def newInstance(): this.type = {
InMemoryRelation(
@@ -487,6 +490,12 @@ case class InMemoryRelation(
cloned
}
+ override def makeCopy(newArgs: Array[AnyRef]): LogicalPlan = {
+ val copied = super.makeCopy(newArgs).asInstanceOf[InMemoryRelation]
+ copied.statsOfPlanToCache = this.statsOfPlanToCache
+ copied
+ }
+
override def simpleString(maxFields: Int): String =
s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2e47f08ac115a..95bba45c3b1d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.PREDICATES
import org.apache.spark.rdd.RDD
@@ -827,6 +828,8 @@ object DataSourceStrategy
val directionV2 = directionV1 match {
case Ascending => SortDirection.ASCENDING
case Descending => SortDirection.DESCENDING
+ case Constant =>
+ throw SparkException.internalError(s"Unexpected catalyst sort direction $Constant")
}
val nullOrderingV2 = nullOrderingV1 match {
case NullsFirst => NullOrdering.NULLS_FIRST
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
index 280fe1068d814..53e2d3f74bb34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Empty2Null, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Constant, Empty2Null, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules.Rule
@@ -199,13 +199,27 @@ object V1WritesUtils {
expressions.exists(_.exists(_.isInstanceOf[Empty2Null]))
}
+ // SortOrder sequence A (outputOrdering) satisfies SortOrder sequence B (requiredOrdering)
+ // if and only if B is an equivalent of A or of A's prefix, after dropping any SortOrder
+ // in B that is satisfied by a constant SortOrder in A.
def isOrderingMatched(
requiredOrdering: Seq[Expression],
outputOrdering: Seq[SortOrder]): Boolean = {
- if (requiredOrdering.length > outputOrdering.length) {
+ val (constantOutputOrdering, nonConstantOutputOrdering) = outputOrdering.partition {
+ case SortOrder(_, Constant, _, _) => true
+ case SortOrder(child, _, _, _) => child.foldable
+ }
+
+ val effectiveRequiredOrdering = requiredOrdering.filterNot { requiredOrder =>
+ constantOutputOrdering.exists { outputOrder =>
+ outputOrder.satisfies(outputOrder.copy(child = requiredOrder))
+ }
+ }
+
+ if (effectiveRequiredOrdering.length > nonConstantOutputOrdering.length) {
false
} else {
- requiredOrdering.zip(outputOrdering).forall {
+ effectiveRequiredOrdering.zip(nonConstantOutputOrdering).forall {
case (requiredOrder, outputOrder) =>
outputOrder.satisfies(outputOrder.copy(child = requiredOrder))
}
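
Worked sketch mirroring the new test below (illustrative only): for INSERT ... SELECT i, j, '0' AS k ... SORT BY k, i, the plan's outputOrdering is [k(const), i.asc] while the required (partition-columns-first) ordering is [k]; the constant entry absorbs it, so no extra sort is inserted and the user's i order survives.

    import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Constant, SortOrder}
    import org.apache.spark.sql.execution.datasources.V1WritesUtils
    import org.apache.spark.sql.types.{IntegerType, StringType}

    val k = AttributeReference("k", StringType)()
    val i = AttributeReference("i", IntegerType)()
    val matched = V1WritesUtils.isOrderingMatched(
      requiredOrdering = Seq(k),
      outputOrdering = Seq(SortOrder(k, Constant), SortOrder(i, Ascending)))
    assert(matched) // nothing left to match against [i.asc]
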
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala
index c2dedda832e2e..3ce11acfc4d1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, ExtractANSIIntervalDays, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecifiedWindowFrame, TimestampAddInterval, TimestampAddYMInterval, UnaryMinus, UnboundedFollowing, UnboundedPreceding, UnsafeProjection, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, Constant, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, ExtractANSIIntervalDays, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecifiedWindowFrame, TimestampAddInterval, TimestampAddYMInterval, UnaryMinus, UnboundedFollowing, UnboundedPreceding, UnsafeProjection, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
@@ -95,7 +95,7 @@ trait WindowEvaluatorFactoryBase {
// Flip the sign of the offset when processing the order is descending
val boundOffset = sortExpr.direction match {
case Descending => UnaryMinus(offset)
- case Ascending => offset
+ case Ascending | Constant => offset
}
// Create the projection which returns the current 'value' modified by adding the offset.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index 80d771428d909..c16b029067240 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -63,10 +63,23 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper
hasLogicalSort: Boolean,
orderingMatched: Boolean,
hasEmpty2Null: Boolean = false)(query: => Unit): Unit = {
- var optimizedPlan: LogicalPlan = null
+ executeAndCheckOrderingAndCustomValidate(
+ hasLogicalSort, orderingMatched, hasEmpty2Null)(query)(_ => ())
+ }
+
+ /**
+ * Execute a write query and check the ordering of the plan, then run custom validation.
+ */
+ protected def executeAndCheckOrderingAndCustomValidate(
+ hasLogicalSort: Boolean,
+ orderingMatched: Boolean,
+ hasEmpty2Null: Boolean = false)(query: => Unit)(
+ customValidate: LogicalPlan => Unit): Unit = {
+ @volatile var optimizedPlan: LogicalPlan = null
val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+ val conf = qe.sparkSession.sessionState.conf
qe.optimizedPlan match {
case w: V1WriteCommand =>
if (hasLogicalSort && conf.getConf(SQLConf.PLANNED_WRITE_ENABLED)) {
@@ -87,7 +100,8 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper
// Check whether the output ordering is matched before FileFormatWriter executes rdd.
assert(FileFormatWriter.outputOrderingMatched == orderingMatched,
- s"Expect: $orderingMatched, Actual: ${FileFormatWriter.outputOrderingMatched}")
+ s"Expect orderingMatched: $orderingMatched, " +
+ s"Actual: ${FileFormatWriter.outputOrderingMatched}")
sparkContext.listenerBus.waitUntilEmpty()
@@ -103,6 +117,8 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper
assert(empty2nullExpr == hasEmpty2Null,
s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan")
+ customValidate(optimizedPlan)
+
spark.listenerManager.unregister(listener)
}
}
@@ -391,4 +407,30 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
}
}
}
+
+ test("v1 write with sort by literal column preserve custom order") {
+ withPlannedWrite { _ =>
+ withTable("t") {
+ sql(
+ """
+ |CREATE TABLE t(i INT, j INT, k STRING) USING PARQUET
+ |PARTITIONED BY (k)
+ |""".stripMargin)
+ executeAndCheckOrderingAndCustomValidate(hasLogicalSort = true, orderingMatched = true) {
+ sql(
+ """
+ |INSERT OVERWRITE t
+ |SELECT i, j, '0' as k FROM t0 SORT BY k, i
+ |""".stripMargin)
+ } { optimizedPlan =>
+ assert {
+ optimizedPlan.outputOrdering.exists {
+ case SortOrder(attr: AttributeReference, _, _, _) => attr.name == "i"
+ case _ => false
+ }
+ }
+ }
+ }
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala
index e0e056be5987c..55bb2c60dcca5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.execution.command
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SortOrder}
import org.apache.spark.sql.execution.datasources.V1WriteCommandSuiteBase
import org.apache.spark.sql.hive.HiveUtils._
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -126,4 +127,35 @@ class V1WriteHiveCommandSuite
}
}
}
+
+ test("v1 write to hive table with sort by literal column preserve custom order") {
+ withConvertMetastore { _ =>
+ withPlannedWrite { _ =>
+ withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
+ withTable("t") {
+ sql(
+ """
+ |CREATE TABLE t(i INT, j INT, k STRING) STORED AS PARQUET
+ |PARTITIONED BY (k)
+ |""".stripMargin)
+ executeAndCheckOrderingAndCustomValidate(
+ hasLogicalSort = true, orderingMatched = true) {
+ sql(
+ """
+ |INSERT OVERWRITE t
+ |SELECT i, j, '0' as k FROM t0 SORT BY k, i
+ |""".stripMargin)
+ } { optimizedPlan =>
+ assert {
+ optimizedPlan.outputOrdering.exists {
+ case SortOrder(attr: AttributeReference, _, _, _) => attr.name == "i"
+ case _ => false
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
}