diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 0596dc00985a..e79000d58350 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -287,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       mapChildren(_.transformDown(rule))
     } else {
       // If the transform function replaces this node with a new one, carry over the tags.
-      afterRule.tags ++= this.tags
+      afterRule.copyTagsFrom(this)
       afterRule.mapChildren(_.transformDown(rule))
     }
   }
@@ -311,7 +311,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       }
     }
     // If the transform function replaces this node with a new one, carry over the tags.
-    newNode.tags ++= this.tags
+    newNode.copyTagsFrom(this)
     newNode
   }
@@ -429,8 +429,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   private def makeCopy(
       newArgs: Array[AnyRef],
       allowEmptyArgs: Boolean): BaseType = attachTree(this, "makeCopy") {
+    val allCtors = getClass.getConstructors
+    if (newArgs.isEmpty && allCtors.isEmpty) {
+      // This is a singleton object which doesn't have any constructor. Just return `this` as we
+      // can't copy it.
+      return this
+    }
+
     // Skip no-arg constructors that are just there for kryo.
-    val ctors = getClass.getConstructors.filter(allowEmptyArgs || _.getParameterTypes.size != 0)
+    val ctors = allCtors.filter(allowEmptyArgs || _.getParameterTypes.size != 0)
     if (ctors.isEmpty) {
       sys.error(s"No valid constructor for $nodeName")
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index c8531e9a046a..1583b8d3a1f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, StringConcat}
+import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.adaptive.InsertAdaptiveSparkPlan
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
@@ -60,36 +60,38 @@ class QueryExecution(
 
   lazy val analyzed: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.ANALYSIS) {
     SparkSession.setActiveSession(sparkSession)
+    // We can't clone `logical` here, which will reset the `_analyzed` flag.
     sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker)
   }
 
   lazy val withCachedData: LogicalPlan = {
     assertAnalyzed()
     assertSupported()
-    sparkSession.sharedState.cacheManager.useCachedData(analyzed)
+    // clone the plan to avoid sharing the plan instance between different stages like analyzing,
+    // optimizing and planning.
+    sparkSession.sharedState.cacheManager.useCachedData(analyzed.clone())
   }
 
   lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) {
-    sparkSession.sessionState.optimizer.executeAndTrack(withCachedData, tracker)
+    // clone the plan to avoid sharing the plan instance between different stages like analyzing,
+    // optimizing and planning.
+    sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker)
   }
 
   lazy val sparkPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) {
     SparkSession.setActiveSession(sparkSession)
-    // Runtime re-optimization requires a unique instance of every node in the logical plan.
-    val logicalPlan = if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
-      optimizedPlan.clone()
-    } else {
-      optimizedPlan
-    }
     // TODO: We use next(), i.e. take the first plan returned by the planner, here for now,
     //       but we will implement to choose the best plan.
-    planner.plan(ReturnAnswer(logicalPlan)).next()
+    // Clone the logical plan here, in case the planner rules change the states of the logical plan.
+    planner.plan(ReturnAnswer(optimizedPlan.clone())).next()
   }
 
   // executedPlan should not be used to initialize any SparkPlan. It should be
   // only used for execution.
   lazy val executedPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) {
-    prepareForExecution(sparkPlan)
+    // clone the plan to avoid sharing the plan instance between different stages like analyzing,
+    // optimizing and planning.
+    prepareForExecution(sparkPlan.clone())
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 1de2b6e0a85d..b77f90d19b62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -223,6 +223,13 @@ case class InMemoryRelation(
       statsOfPlanToCache).asInstanceOf[this.type]
   }
 
+  // override `clone` since the default implementation won't carry over mutable states.
+  override def clone(): LogicalPlan = {
+    val cloned = this.copy()
+    cloned.statsOfPlanToCache = this.statsOfPlanToCache
+    cloned
+  }
+
   override def simpleString(maxFields: Int): String =
     s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
index 45c62b467657..39b08e2894dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData
+import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPlan}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.types.{StringType, StructField, StructType}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index f29e7869fb27..a1de287b93f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -52,4 +52,10 @@ case class SaveIntoDataSourceCommand(
     val redacted = SQLConf.get.redactOptions(options)
     s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}"
   }
+
+  // Override `clone` since the default implementation will turn `CaseInsensitiveMap` to a normal
+  // map.
+  override def clone(): LogicalPlan = {
+    SaveIntoDataSourceCommand(query.clone(), dataSource, options, mode)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index 7de5e826f667..39c87c9eeb47 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -18,8 +18,10 @@ package org.apache.spark.sql.execution
 
 import scala.io.Source
 
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, FastOperator}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 
@@ -137,5 +139,56 @@ class QueryExecutionSuite extends SharedSQLContext {
       (_: LogicalPlan) => throw new Error("error"))
     val error = intercept[Error](qe.toString)
     assert(error.getMessage.contains("error"))
+
+    spark.experimental.extraStrategies = Nil
+  }
+
+  test("SPARK-28346: clone the query plan between different stages") {
+    val tag1 = new TreeNodeTag[String]("a")
+    val tag2 = new TreeNodeTag[String]("b")
+    val tag3 = new TreeNodeTag[String]("c")
+
+    def assertNoTag(tag: TreeNodeTag[String], plans: QueryPlan[_]*): Unit = {
+      plans.foreach { plan =>
+        assert(plan.getTagValue(tag).isEmpty)
+      }
+    }
+
+    val df = spark.range(10)
+    val analyzedPlan = df.queryExecution.analyzed
+    val cachedPlan = df.queryExecution.withCachedData
+    val optimizedPlan = df.queryExecution.optimizedPlan
+
+    analyzedPlan.setTagValue(tag1, "v")
+    assertNoTag(tag1, cachedPlan, optimizedPlan)
+
+    cachedPlan.setTagValue(tag2, "v")
+    assertNoTag(tag2, analyzedPlan, optimizedPlan)
+
+    optimizedPlan.setTagValue(tag3, "v")
+    assertNoTag(tag3, analyzedPlan, cachedPlan)
+
+    val tag4 = new TreeNodeTag[String]("d")
+    try {
+      spark.experimental.extraStrategies = Seq(new SparkStrategy() {
+        override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+          plan.foreach {
+            case r: org.apache.spark.sql.catalyst.plans.logical.Range =>
+              r.setTagValue(tag4, "v")
+            case _ =>
+          }
+          Seq(FastOperator(plan.output))
+        }
+      })
+      // trigger planning
+      df.queryExecution.sparkPlan
+      assert(optimizedPlan.getTagValue(tag4).isEmpty)
+    } finally {
+      spark.experimental.extraStrategies = Nil
+    }
+
+    val tag5 = new TreeNodeTag[String]("e")
+    df.queryExecution.executedPlan.setTagValue(tag5, "v")
+    assertNoTag(tag5, df.queryExecution.sparkPlan)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index b3a5c687f775..b4c1472ecb81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -187,7 +187,7 @@ class PartitionBatchPruningSuite
     val result = df.collect().map(_(0)).toArray
     assert(result.length === 1)
 
-    val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect {
+    val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect {
       case in: InMemoryTableScanExec => (in.readPartitions.value, in.readBatches.value)
     }.head
     assert(readPartitions === 5)
@@ -208,7 +208,7 @@ class PartitionBatchPruningSuite
        df.collect().map(_(0)).toArray
      }
 
-     val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect {
+     val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect {
        case in: InMemoryTableScanExec => (in.readPartitions.value, in.readBatches.value)
      }.head