diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 1d7cf5455e57b..e9e897cfb78ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -151,6 +151,7 @@ package object dsl extends SQLConfHelper { def asc: SortOrder = SortOrder(expr, Ascending) def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Seq.empty) + def const: SortOrder = SortOrder(expr, Constant) def desc: SortOrder = SortOrder(expr, Descending) def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Seq.empty) def as(alias: String): NamedExpression = Alias(expr, alias)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 824024a84cbad..73849f831237c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -45,6 +45,11 @@ case object Descending extends SortDirection { override def defaultNullOrdering: NullOrdering = NullsLast } +case object Constant extends SortDirection { + override def sql: String = "CONST" + override def defaultNullOrdering: NullOrdering = NullsFirst +} + case object NullsFirst extends NullOrdering { override def sql: String = "NULLS FIRST" } @@ -69,8 +74,13 @@ case class SortOrder( override def children: Seq[Expression] = child +: sameOrderExpressions - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(dataType, prettyName) + override def checkInputDataTypes(): TypeCheckResult = { + if (direction == Constant) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeUtils.checkForOrderingExpr(dataType, prettyName) + } + } override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable @@ -81,8 +91,8 @@ case class SortOrder( def isAscending: Boolean = direction == Ascending def satisfies(required: SortOrder): Boolean = { - children.exists(required.child.semanticEquals) && - direction == required.direction && nullOrdering == required.nullOrdering + children.exists(required.child.semanticEquals) && (direction == Constant || + direction == required.direction && nullOrdering == required.nullOrdering) } override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): SortOrder = @@ -101,21 +111,38 @@ object SortOrder { * Returns if a sequence of SortOrder satisfies another sequence of SortOrder. * * SortOrder sequence A satisfies SortOrder sequence B if and only if B is an equivalent of A - * or of A's prefix. Here are examples of ordering A satisfying ordering B: + * or of A's prefix, except for any SortOrder in B that is satisfied by a constant SortOrder in A.
+ * + * Here are examples of ordering A satisfying ordering B: * */ - def orderingSatisfies(ordering1: Seq[SortOrder], ordering2: Seq[SortOrder]): Boolean = { - if (ordering2.isEmpty) { - true - } else if (ordering2.length > ordering1.length) { + def orderingSatisfies( + providedOrdering: Seq[SortOrder], requiredOrdering: Seq[SortOrder]): Boolean = { + if (requiredOrdering.isEmpty) { + return true + } + + val (constantProvidedOrdering, nonConstantProvidedOrdering) = providedOrdering.partition { + case SortOrder(_, Constant, _, _) => true + case SortOrder(child, _, _, _) => child.foldable + } + + val effectiveRequiredOrdering = requiredOrdering.filterNot { requiredOrder => + constantProvidedOrdering.exists { providedOrder => + providedOrder.satisfies(requiredOrder) + } + } + + if (effectiveRequiredOrdering.length > nonConstantProvidedOrdering.length) { false } else { - ordering2.zip(ordering1).forall { - case (o2, o1) => o1.satisfies(o2) + effectiveRequiredOrdering.zip(nonConstantProvidedOrdering).forall { + case (required, provided) => provided.satisfies(required) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index 7a4f04bf04f7a..b61dae5850618 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Constant, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder} import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType @@ -199,6 +199,8 @@ case class Mode( this.copy(child = child, reverseOpt = Some(true)) case SortOrder(child, Descending, _, _) => this.copy(child = child, reverseOpt = Some(false)) + case SortOrder(child, Constant, _, _) => + this.copy(child = child) } case _ => this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala index 6dfa1b499df23..942c06f60d123 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala @@ -382,7 +382,7 @@ case class PercentileCont(left: Expression, right: Expression, reverse: Boolean nodeName, 1, orderingWithinGroup.length) } orderingWithinGroup.head match { - case SortOrder(child, Ascending, _, _) => this.copy(left = child) + case SortOrder(child, Ascending | Constant, _, _) => this.copy(left = child) case SortOrder(child, Descending, _, _) => this.copy(left = child, reverse = true) } } @@ -440,7 +440,7 @@ case class PercentileDisc( nodeName, 1, orderingWithinGroup.length) } orderingWithinGroup.head match { - case SortOrder(expr, 
Ascending, _, _) => this.copy(child = expr) + case SortOrder(expr, Ascending | Constant, _, _) => this.copy(child = expr) case SortOrder(expr, Descending, _, _) => this.copy(child = expr, reverse = true) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index fc65c24afcb8f..b5941cd5eddcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1919,7 +1919,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { object EliminateSorts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(SORT)) { case s @ Sort(orders, _, child, _) if orders.isEmpty || orders.exists(_.child.foldable) => - val newOrders = orders.filterNot(_.child.foldable) + val newOrders = orders.filterNot(o => o.direction != Constant && o.child.foldable) if (newOrders.isEmpty) { child } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala index e1a9e8b5ea810..efbd7b0c8b810 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/AliasAwareOutputExpression.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import scala.collection.mutable import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Empty2Null, Expression, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Constant, Empty2Null, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.internal.SQLConf /** @@ -128,6 +128,8 @@ trait AliasAwareQueryOutputOrdering[T <: QueryPlan[T]] } } } - newOrdering.takeWhile(_.isDefined).flatten.toSeq + newOrdering.takeWhile(_.isDefined).flatten.toSeq ++ outputExpressions.collect { + case a @ Alias(child, _) if child.foldable => SortOrder(a.toAttribute, Constant) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ad6939422b976..5a7548f13f1f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -912,7 +912,8 @@ case class Sort( override def maxRowsPerPartition: Option[Long] = { if (global) maxRows else child.maxRowsPerPartition } - override def outputOrdering: Seq[SortOrder] = order + override def outputOrdering: Seq[SortOrder] = + order ++ child.outputOrdering.filter(_.direction == Constant) final override val nodePatterns: Seq[TreePattern] = Seq(SORT) override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala index 06c8b5ccef652..5facdaeb1aca1 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -43,9 +43,10 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { val sortOrder = direction match { case Ascending => BoundReference(0, dataType, nullable = true).asc case Descending => BoundReference(0, dataType, nullable = true).desc + case Constant => BoundReference(0, dataType, nullable = true).const } val expectedCompareResult = direction match { - case Ascending => signum(expected) + case Ascending | Constant => signum(expected) case Descending => -1 * signum(expected) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 11fde41aae9e4..708e054664c57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -46,7 +46,8 @@ case class SortExec( override def output: Seq[Attribute] = child.output - override def outputOrdering: Seq[SortOrder] = sortOrder + override def outputOrdering: Seq[SortOrder] = + sortOrder ++ child.outputOrdering.filter(_.direction == Constant) // sort performed is local within a given partition so will retain // child operator's partitioning @@ -73,15 +74,17 @@ case class SortExec( * should make it public. */ def createSorter(): UnsafeExternalRowSorter = { + val effectiveSortOrder = sortOrder.filterNot(_.direction == Constant) + rowSorter = new ThreadLocal[UnsafeExternalRowSorter]() val ordering = RowOrdering.create(sortOrder, output) // The comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) + val boundSortExpression = BindReferences.bindReference(effectiveSortOrder.head, output) val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - val canUseRadixSort = enableRadixSort && sortOrder.length == 1 && + val canUseRadixSort = enableRadixSort && effectiveSortOrder.length == 1 && SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression) // The generator for prefix diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 4b561b813067e..36dd7a76c0d21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -63,6 +63,8 @@ object SortPrefixUtils { PrefixComparators.STRING_DESC_NULLS_FIRST case Descending => PrefixComparators.STRING_DESC + case Constant => + NoOpPrefixComparator } } @@ -76,6 +78,8 @@ object SortPrefixUtils { PrefixComparators.BINARY_DESC_NULLS_FIRST case Descending => PrefixComparators.BINARY_DESC + case Constant => + NoOpPrefixComparator } } @@ -89,6 +93,8 @@ object SortPrefixUtils { PrefixComparators.LONG_DESC_NULLS_FIRST case Descending => PrefixComparators.LONG_DESC + case Constant => + NoOpPrefixComparator } } @@ -102,6 +108,8 @@ object SortPrefixUtils { PrefixComparators.DOUBLE_DESC_NULLS_FIRST case Descending => PrefixComparators.DOUBLE_DESC + case Constant => + NoOpPrefixComparator } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index eabbc7fc74f50..bf7491625fa03 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -439,9 +439,7 @@ case class InMemoryRelation( override def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) override def doCanonicalize(): logical.LogicalPlan = - copy(output = output.map(QueryPlan.normalizeExpressions(_, output)), - cacheBuilder, - outputOrdering) + withOutput(output.map(QueryPlan.normalizeExpressions(_, output))) @transient val partitionStatistics = new PartitionStatistics(output) @@ -469,8 +467,13 @@ case class InMemoryRelation( } } - def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = - InMemoryRelation(newOutput, cacheBuilder, outputOrdering, statsOfPlanToCache) + def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { + val map = AttributeMap(output.zip(newOutput)) + val newOutputOrdering = outputOrdering + .map(_.transform { case a: Attribute => map(a) }) + .asInstanceOf[Seq[SortOrder]] + InMemoryRelation(newOutput, cacheBuilder, newOutputOrdering, statsOfPlanToCache) + } override def newInstance(): this.type = { InMemoryRelation( @@ -487,6 +490,12 @@ case class InMemoryRelation( cloned } + override def makeCopy(newArgs: Array[AnyRef]): LogicalPlan = { + val copied = super.makeCopy(newArgs).asInstanceOf[InMemoryRelation] + copied.statsOfPlanToCache = this.statsOfPlanToCache + copied + } + override def simpleString(maxFields: Int): String = s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2e47f08ac115a..95bba45c3b1d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._ import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.PREDICATES import org.apache.spark.rdd.RDD @@ -827,6 +828,8 @@ object DataSourceStrategy val directionV2 = directionV1 match { case Ascending => SortDirection.ASCENDING case Descending => SortDirection.DESCENDING + case Constant => + throw SparkException.internalError(s"Unexpected catalyst sort direction $Constant") } val nullOrderingV2 = nullOrderingV1 match { case NullsFirst => NullOrdering.NULLS_FIRST diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala index 280fe1068d814..53e2d3f74bb34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Empty2Null, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, 
BitwiseAnd, Constant, Empty2Null, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules.Rule @@ -199,13 +199,27 @@ object V1WritesUtils { expressions.exists(_.exists(_.isInstanceOf[Empty2Null])) } + // SortOrder sequence A (outputOrdering) satisfies SortOrder sequence B (requiredOrdering) + // if and only if B is an equivalent of A or of A's prefix, except for any SortOrder in B + // that is satisfied by a constant SortOrder in A. def isOrderingMatched( requiredOrdering: Seq[Expression], outputOrdering: Seq[SortOrder]): Boolean = { - if (requiredOrdering.length > outputOrdering.length) { + val (constantOutputOrdering, nonConstantOutputOrdering) = outputOrdering.partition { + case SortOrder(_, Constant, _, _) => true + case SortOrder(child, _, _, _) => child.foldable + } + + val effectiveRequiredOrdering = requiredOrdering.filterNot { requiredOrder => + constantOutputOrdering.exists { outputOrder => + outputOrder.satisfies(outputOrder.copy(child = requiredOrder)) + } + } + + if (effectiveRequiredOrdering.length > nonConstantOutputOrdering.length) { false } else { - requiredOrdering.zip(outputOrdering).forall { + effectiveRequiredOrdering.zip(nonConstantOutputOrdering).forall { case (requiredOrder, outputOrder) => outputOrder.satisfies(outputOrder.copy(child = requiredOrder)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala index c2dedda832e2e..3ce11acfc4d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, ExtractANSIIntervalDays, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecifiedWindowFrame, TimestampAddInterval, TimestampAddYMInterval, UnaryMinus, UnboundedFollowing, UnboundedPreceding, UnsafeProjection, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, Ascending, Attribute, BoundReference, Constant, CurrentRow, DateAdd, DateAddYMInterval, DecimalAddNoOverflowCheck, Descending, Expression, ExtractANSIIntervalDays, FrameLessOffsetWindowFunction, FrameType, IdentityProjection, IntegerLiteral, MutableProjection, NamedExpression, OffsetWindowFunction, PythonFuncExpression, RangeFrame, RowFrame, RowOrdering, SortOrder, SpecifiedWindowFrame, TimestampAddInterval, TimestampAddYMInterval, UnaryMinus, UnboundedFollowing, UnboundedPreceding, UnsafeProjection, WindowExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf @@ -95,7 +95,7 @@ trait WindowEvaluatorFactoryBase { // Flip the sign of the offset
when processing the order is descending val boundOffset = sortExpr.direction match { case Descending => UnaryMinus(offset) - case Ascending => offset + case Ascending | Constant => offset } // Create the projection which returns the current 'value' modified by adding the offset. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala index 80d771428d909..c16b029067240 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala @@ -63,10 +63,23 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper hasLogicalSort: Boolean, orderingMatched: Boolean, hasEmpty2Null: Boolean = false)(query: => Unit): Unit = { - var optimizedPlan: LogicalPlan = null + executeAndCheckOrderingAndCustomValidate( + hasLogicalSort, orderingMatched, hasEmpty2Null)(query)(_ => ()) + } + + /** + * Execute a write query and check the ordering of the plan, then run custom validation. + */ + protected def executeAndCheckOrderingAndCustomValidate( + hasLogicalSort: Boolean, + orderingMatched: Boolean, + hasEmpty2Null: Boolean = false)(query: => Unit)( + customValidate: LogicalPlan => Unit): Unit = { + @volatile var optimizedPlan: LogicalPlan = null val listener = new QueryExecutionListener { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + val conf = qe.sparkSession.sessionState.conf qe.optimizedPlan match { case w: V1WriteCommand => if (hasLogicalSort && conf.getConf(SQLConf.PLANNED_WRITE_ENABLED)) { @@ -87,7 +100,8 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper // Check whether the output ordering is matched before FileFormatWriter executes rdd. assert(FileFormatWriter.outputOrderingMatched == orderingMatched, - s"Expect: $orderingMatched, Actual: ${FileFormatWriter.outputOrderingMatched}") + s"Expect orderingMatched: $orderingMatched, " + + s"Actual: ${FileFormatWriter.outputOrderingMatched}") sparkContext.listenerBus.waitUntilEmpty() @@ -103,6 +117,8 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper assert(empty2nullExpr == hasEmpty2Null, s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr.
Plan:\n$optimizedPlan") + customValidate(optimizedPlan) + spark.listenerManager.unregister(listener) } } @@ -391,4 +407,30 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write } } } + + test("v1 write with sort by literal column preserve custom order") { + withPlannedWrite { _ => + withTable("t") { + sql( + """ + |CREATE TABLE t(i INT, j INT, k STRING) USING PARQUET + |PARTITIONED BY (k) + |""".stripMargin) + executeAndCheckOrderingAndCustomValidate(hasLogicalSort = true, orderingMatched = true) { + sql( + """ + |INSERT OVERWRITE t + |SELECT i, j, '0' as k FROM t0 SORT BY k, i + |""".stripMargin) + } { optimizedPlan => + assert { + optimizedPlan.outputOrdering.exists { + case SortOrder(attr: AttributeReference, _, _, _) => attr.name == "i" + case _ => false + } + } + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala index e0e056be5987c..55bb2c60dcca5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution.command import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SortOrder} import org.apache.spark.sql.execution.datasources.V1WriteCommandSuiteBase import org.apache.spark.sql.hive.HiveUtils._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -126,4 +127,35 @@ class V1WriteHiveCommandSuite } } } + + test("v1 write to hive table with sort by literal column preserve custom order") { + withCovnertMetastore { _ => + withPlannedWrite { _ => + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable("t") { + sql( + """ + |CREATE TABLE t(i INT, j INT, k STRING) STORED AS PARQUET + |PARTITIONED BY (k) + |""".stripMargin) + executeAndCheckOrderingAndCustomValidate( + hasLogicalSort = true, orderingMatched = true) { + sql( + """ + |INSERT OVERWRITE t + |SELECT i, j, '0' as k FROM t0 SORT BY k, i + |""".stripMargin) + } { optimizedPlan => + assert { + optimizedPlan.outputOrdering.exists { + case SortOrder(attr: AttributeReference, _, _, _) => attr.name == "i" + case _ => false + } + } + } + } + } + } + } + } }