sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -382,9 +382,7 @@ case class InMemoryRelation(
   override def innerChildren: Seq[SparkPlan] = Seq(cachedPlan)
 
   override def doCanonicalize(): logical.LogicalPlan =
-    copy(output = output.map(QueryPlan.normalizeExpressions(_, output)),
-      cacheBuilder,
-      outputOrdering)
+    withOutput(output.map(QueryPlan.normalizeExpressions(_, output)))
 
   @transient val partitionStatistics = new PartitionStatistics(output)
 
@@ -412,8 +410,13 @@ case class InMemoryRelation(
     }
   }
 
-  def withOutput(newOutput: Seq[Attribute]): InMemoryRelation =
-    InMemoryRelation(newOutput, cacheBuilder, outputOrdering, statsOfPlanToCache)
+  def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
+    val map = AttributeMap(output.zip(newOutput))
+    val newOutputOrdering = outputOrdering
+      .map(_.transform { case a: Attribute => map(a) })
+      .asInstanceOf[Seq[SortOrder]]
+    InMemoryRelation(newOutput, cacheBuilder, newOutputOrdering, statsOfPlanToCache)
+  }

   override def newInstance(): this.type = {
     InMemoryRelation(
@@ -430,6 +433,12 @@ case class InMemoryRelation(
     cloned
   }
 
+  override def makeCopy(newArgs: Array[AnyRef]): LogicalPlan = {
+    val copied = super.makeCopy(newArgs).asInstanceOf[InMemoryRelation]
+    copied.statsOfPlanToCache = this.statsOfPlanToCache
+    copied
+  }
+
   override def simpleString(maxFields: Int): String =
     s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}"
 }
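
Why the `withOutput` change matters: the old implementation reused `outputOrdering` verbatim, so a relation rebuilt for a new output still carried sort orders referencing the old attributes (stale expression IDs). A minimal sketch of the remapping idea, using plain case classes as stand-ins for catalyst's `Attribute`/`SortOrder` (all names below are illustrative, not Spark's API):

// Stand-ins for catalyst classes; attributes are equal only if name AND exprId match.
case class Attr(name: String, exprId: Long)
case class Order(child: Attr)

object WithOutputSketch extends App {
  val oldOutput = Seq(Attr("i", 1L), Attr("k", 2L))
  val newOutput = Seq(Attr("i", 10L), Attr("k", 20L)) // same names, fresh IDs

  // Positional mapping, like AttributeMap(output.zip(newOutput)) in the patch.
  val map = oldOutput.zip(newOutput).toMap

  // Remap the cached ordering onto the new output. Reusing the old ordering verbatim
  // would leave Order(Attr("i", 1)) pointing at an attribute the relation no longer has.
  val oldOrdering = Seq(Order(Attr("i", 1L)))
  val newOrdering = oldOrdering.map(o => Order(map.getOrElse(o.child, o.child)))

  assert(newOrdering == Seq(Order(Attr("i", 10L))))
}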
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Empty2Null, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder}
+import org.apache.spark.sql.catalyst.optimizer.{EliminateSorts, FoldablePropagation}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -98,13 +99,15 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
     assert(empty2NullPlan.output.length == query.output.length)
     val attrMap = AttributeMap(query.output.zip(empty2NullPlan.output))
 
-    // Rewrite the attribute references in the required ordering to use the new output.
-    val requiredOrdering = write.requiredOrdering.map(_.transform {
-      case a: Attribute => attrMap.getOrElse(a, a)
-    }.asInstanceOf[SortOrder])
-    val outputOrdering = empty2NullPlan.outputOrdering
-    val orderingMatched = isOrderingMatched(requiredOrdering.map(_.child), outputOrdering)
-    if (orderingMatched) {
+    // Rewrite the attribute references in the required ordering to use the new output,
+    // then eliminate foldable ordering.
+    val requiredOrdering = {
+      val ordering = write.requiredOrdering.map(_.transform {
+        case a: Attribute => attrMap.getOrElse(a, a)
+      }.asInstanceOf[SortOrder])
+      eliminateFoldableOrdering(ordering, empty2NullPlan).outputOrdering
+    }
+    if (isOrderingMatched(requiredOrdering.map(_.child), empty2NullPlan.outputOrdering)) {
       empty2NullPlan
     } else {
       Sort(requiredOrdering, global = false, empty2NullPlan)
@@ -200,6 +203,15 @@ object V1WritesUtils {
     expressions.exists(_.exists(_.isInstanceOf[Empty2Null]))
   }
 
+  // SPARK-53738: the required ordering inferred from the table spec (partitioning,
+  // bucketing, etc.) may contain foldable sort expressions, making the optimized query's
+  // output ordering appear to mismatch. Compute the required ordering more accurately by
+  // wrapping the input query in a temporary Sort node and removing the foldable entries.
+  def eliminateFoldableOrdering(ordering: Seq[SortOrder], query: LogicalPlan): LogicalPlan =
+    EliminateSorts(FoldablePropagation(Sort(ordering, global = false, query)))
+
+  // The comparison ignores SortDirection and NullOrdering since they don't matter
+  // for writing cases.
   def isOrderingMatched(
       requiredOrdering: Seq[Expression],
       outputOrdering: Seq[SortOrder]): Boolean = {
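
To see what `eliminateFoldableOrdering` does, here is a rough standalone sketch that drives the same two optimizer rules directly (assumes the spark-catalyst artifact on the classpath; the plan shapes are simplified):

import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, Literal, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.{EliminateSorts, FoldablePropagation}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Sort}
import org.apache.spark.sql.types.IntegerType

object FoldableOrderingSketch extends App {
  val i = AttributeReference("i", IntegerType)()
  val j = AttributeReference("j", IntegerType)()

  // SELECT i, j, '0' AS k FROM relation -- k aliases a foldable literal.
  val kAlias = Alias(Literal("0"), "k")()
  val projected = Project(Seq(i, j, kAlias), LocalRelation(i, j))

  // Required write ordering inferred from the table spec: partition column k, then i.
  val required = Seq(SortOrder(kAlias.toAttribute, Ascending), SortOrder(i, Ascending))

  // FoldablePropagation rewrites the k reference to Literal("0") inside the fake Sort;
  // EliminateSorts then drops the foldable entry, leaving a sort on i alone.
  val cleaned = EliminateSorts(FoldablePropagation(Sort(required, global = false, projected)))
  println(cleaned) // expect: Sort [i ASC NULLS FIRST], false, Project [...]
}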
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -62,10 +62,23 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils {
       hasLogicalSort: Boolean,
       orderingMatched: Boolean,
       hasEmpty2Null: Boolean = false)(query: => Unit): Unit = {
-    var optimizedPlan: LogicalPlan = null
+    executeAndCheckOrderingAndCustomValidate(
+      hasLogicalSort, Some(orderingMatched), hasEmpty2Null)(query)(_ => ())
+  }
+
+  /**
+   * Execute a write query, check the ordering of the plan, then run custom validation.
+   */
+  protected def executeAndCheckOrderingAndCustomValidate(
+      hasLogicalSort: Boolean,
+      orderingMatched: Option[Boolean],
+      hasEmpty2Null: Boolean = false)(query: => Unit)(
+      customValidate: LogicalPlan => Unit): Unit = {
+    @volatile var optimizedPlan: LogicalPlan = null
 
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        val conf = qe.sparkSession.sessionState.conf
         qe.optimizedPlan match {
           case w: V1WriteCommand =>
             if (hasLogicalSort && conf.getConf(SQLConf.PLANNED_WRITE_ENABLED)) {
@@ -84,9 +97,12 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils {

     query
 
-    // Check whether the output ordering is matched before FileFormatWriter executes the rdd.
-    assert(FileFormatWriter.outputOrderingMatched == orderingMatched,
-      s"Expect: $orderingMatched, Actual: ${FileFormatWriter.outputOrderingMatched}")
+    orderingMatched.foreach { matched =>
+      // Check whether the output ordering is matched before FileFormatWriter executes the rdd.
+      assert(FileFormatWriter.outputOrderingMatched == matched,
+        s"Expect orderingMatched: $matched, " +
+          s"Actual: ${FileFormatWriter.outputOrderingMatched}")
+    }
 
     sparkContext.listenerBus.waitUntilEmpty()

@@ -102,6 +118,8 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils {
     assert(empty2nullExpr == hasEmpty2Null,
       s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan")
 
+    customValidate(optimizedPlan)
+
     spark.listenerManager.unregister(listener)
   }
 }
@@ -390,4 +408,33 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
       }
     }
   }
+
+  test("v1 write with sort by literal column preserves custom order") {
+    withPlannedWrite { enabled =>
+      withTable("t") {
+        sql(
+          """
+            |CREATE TABLE t(i INT, j INT, k STRING) USING PARQUET
+            |PARTITIONED BY (k)
+            |""".stripMargin)
+        // Skip checking orderingMatched temporarily to avoid touching `FileFormatWriter`,
+        // see details at https://github.com/apache/spark/pull/52584#issuecomment-3407716019
+        executeAndCheckOrderingAndCustomValidate(
+          hasLogicalSort = true, orderingMatched = None) {
+          sql(
+            """
+              |INSERT OVERWRITE t
+              |SELECT i, j, '0' as k FROM t0 SORT BY k, i
+              |""".stripMargin)
+        } { optimizedPlan =>
+          assert {
+            optimizedPlan.outputOrdering.exists {
+              case SortOrder(attr: AttributeReference, _, _, _) => attr.name == "i"
+              case _ => false
+            }
+          }
+        }
+      }
+    }
+  }
 }
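
In end-user terms, the scenario this test pins down looks like the following spark-shell sketch (`t0` stands for any existing table with INT columns i and j):

// Hypothetical spark-shell session illustrating the fixed behavior.
spark.sql("CREATE TABLE t(i INT, j INT, k STRING) USING PARQUET PARTITIONED BY (k)")

// Every row writes the literal '0' into partition column k, so within each partition the
// effective ordering is just i. Before the fix, V1Writes compared the inferred requirement
// (k) against the optimized output ordering (i, after the foldable k entry was optimized
// away) and injected a redundant Sort on the constant k, losing the SORT BY guarantee on i.
// With the fix, the foldable entry is dropped from the requirement and the i order survives.
spark.sql("INSERT OVERWRITE t SELECT i, j, '0' AS k FROM t0 SORT BY k, i")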
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hive.execution.command
 
 import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SortOrder}
 import org.apache.spark.sql.execution.datasources.V1WriteCommandSuiteBase
 import org.apache.spark.sql.hive.test.TestHiveSingleton

@@ -105,4 +106,35 @@ class V1WriteHiveCommandSuite
       }
     }
   }
+
+  test("v1 write to hive table with sort by literal column preserves custom order") {
+    withPlannedWrite { enabled =>
+      withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
+        withTable("t") {
+          sql(
+            """
+              |CREATE TABLE t(i INT, j INT, k STRING) STORED AS PARQUET
+              |PARTITIONED BY (k)
+              |""".stripMargin)
+          // Skip checking orderingMatched temporarily to avoid touching `FileFormatWriter`,
+          // see details at https://github.com/apache/spark/pull/52584#issuecomment-3407716019
+          executeAndCheckOrderingAndCustomValidate(
+            hasLogicalSort = true, orderingMatched = None) {
+            sql(
+              """
+                |INSERT OVERWRITE t
+                |SELECT i, j, '0' as k FROM t0 SORT BY k, i
+                |""".stripMargin)
+          } { optimizedPlan =>
+            assert {
+              optimizedPlan.outputOrdering.exists {
+                case SortOrder(attr: AttributeReference, _, _, _) => attr.name == "i"
+                case _ => false
+              }
+            }
+          }
+        }
+      }
+    }
+  }
 }