From 92095b7dc47b86cb99a06ba21bc0d681728b62a8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 11 Jul 2019 14:00:07 +0800 Subject: [PATCH 1/5] clone the query plan between analyzer, optimizer and planner --- .../spark/sql/execution/QueryExecution.scala | 16 ++++------ .../execution/columnar/InMemoryRelation.scala | 29 +++++++++++++------ .../sql/execution/command/SetCommand.scala | 4 ++- .../spark/sql/execution/command/cache.scala | 2 ++ .../SaveIntoDataSourceCommand.scala | 4 +++ .../sql/execution/QueryExecutionSuite.scala | 26 ++++++++++++++++- .../columnar/PartitionBatchPruningSuite.scala | 4 +-- 7 files changed, 61 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index c8531e9a046a..9fa8de84ef67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, StringConcat} +import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.adaptive.InsertAdaptiveSparkPlan import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} @@ -66,30 +66,24 @@ class QueryExecution( lazy val withCachedData: LogicalPlan = { assertAnalyzed() assertSupported() - sparkSession.sharedState.cacheManager.useCachedData(analyzed) + sparkSession.sharedState.cacheManager.useCachedData(analyzed.clone()) } lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) { - sparkSession.sessionState.optimizer.executeAndTrack(withCachedData, tracker) + sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker) } lazy val sparkPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { SparkSession.setActiveSession(sparkSession) - // Runtime re-optimization requires a unique instance of every node in the logical plan. - val logicalPlan = if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { - optimizedPlan.clone() - } else { - optimizedPlan - } // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. - planner.plan(ReturnAnswer(logicalPlan)).next() + planner.plan(ReturnAnswer(optimizedPlan.clone())).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. 
lazy val executedPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { - prepareForExecution(sparkPlan) + prepareForExecution(sparkPlan.clone()) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 1de2b6e0a85d..b326cfaf5a44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel @@ -149,14 +150,14 @@ object InMemoryRelation { logicalPlan: LogicalPlan): InMemoryRelation = { val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName) val relation = new InMemoryRelation(child.output, cacheBuilder, logicalPlan.outputOrdering) - relation.statsOfPlanToCache = logicalPlan.stats + relation.setStatsOfPlanToCache(logicalPlan.stats) relation } def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = { val relation = new InMemoryRelation( cacheBuilder.cachedPlan.output, cacheBuilder, logicalPlan.outputOrdering) - relation.statsOfPlanToCache = logicalPlan.stats + relation.setStatsOfPlanToCache(logicalPlan.stats) relation } @@ -166,9 +167,11 @@ object InMemoryRelation { outputOrdering: Seq[SortOrder], statsOfPlanToCache: Statistics): InMemoryRelation = { val relation = InMemoryRelation(output, cacheBuilder, outputOrdering) - relation.statsOfPlanToCache = statsOfPlanToCache + relation.setStatsOfPlanToCache(statsOfPlanToCache) relation } + + val STATS_OF_PLAN_TO_CACHE_TAG = new TreeNodeTag[Statistics]("stats_of_plan_to_cache") } case class InMemoryRelation( @@ -176,8 +179,15 @@ case class InMemoryRelation( @transient cacheBuilder: CachedRDDBuilder, override val outputOrdering: Seq[SortOrder]) extends logical.LeafNode with MultiInstanceRelation { + import InMemoryRelation.STATS_OF_PLAN_TO_CACHE_TAG + + def setStatsOfPlanToCache(statsOfPlanToCache: Statistics): Unit = { + setTagValue(STATS_OF_PLAN_TO_CACHE_TAG, statsOfPlanToCache) + } - @volatile var statsOfPlanToCache: Statistics = null + def getStatsOfPlanToCache(): Statistics = { + getTagValue(STATS_OF_PLAN_TO_CACHE_TAG).get + } override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) @@ -193,19 +203,20 @@ case class InMemoryRelation( private[sql] def updateStats( rowCount: Long, newColStats: Map[Attribute, ColumnStat]): Unit = this.synchronized { + val statsOfPlanToCache = getStatsOfPlanToCache() val newStats = statsOfPlanToCache.copy( rowCount = Some(rowCount), attributeStats = AttributeMap((statsOfPlanToCache.attributeStats ++ newColStats).toSeq) ) - statsOfPlanToCache = newStats + setStatsOfPlanToCache(newStats) } override def computeStats(): Statistics = { if (!cacheBuilder.isCachedColumnBuffersLoaded) { // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. 
- statsOfPlanToCache + getStatsOfPlanToCache() } else { - statsOfPlanToCache.copy( + getStatsOfPlanToCache().copy( sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue, rowCount = Some(cacheBuilder.rowCountStats.value.longValue) ) @@ -213,14 +224,14 @@ case class InMemoryRelation( } def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = - InMemoryRelation(newOutput, cacheBuilder, outputOrdering, statsOfPlanToCache) + InMemoryRelation(newOutput, cacheBuilder, outputOrdering, getStatsOfPlanToCache()) override def newInstance(): this.type = { InMemoryRelation( output.map(_.newInstance()), cacheBuilder, outputOrdering, - statsOfPlanToCache).asInstanceOf[this.type] + getStatsOfPlanToCache()).asInstanceOf[this.type] } override def simpleString(maxFields: Int): String = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 45c62b467657..ab86ff56c6a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData +import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPlan} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.{StringType, StructField, StructType} @@ -168,4 +168,6 @@ case object ResetCommand extends RunnableCommand with IgnoreCachedData { sparkSession.sessionState.conf.clear() Seq.empty[Row] } + + override def clone(): LogicalPlan = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 7b00769308a4..cc89911faafb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -89,4 +89,6 @@ case object ClearCacheCommand extends RunnableCommand with IgnoreCachedData { sparkSession.catalog.clearCache() Seq.empty[Row] } + + override def clone(): LogicalPlan = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index f29e7869fb27..2e3f19044f26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -52,4 +52,8 @@ case class SaveIntoDataSourceCommand( val redacted = SQLConf.get.redactOptions(options) s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" } + + override def clone(): LogicalPlan = { + SaveIntoDataSourceCommand(query.clone(), dataSource, options, mode) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 7de5e826f667..1ace35efc591 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution import scala.io.Source import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, SubqueryAlias} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -138,4 +139,27 @@ class QueryExecutionSuite extends SharedSQLContext { val error = intercept[Error](qe.toString) assert(error.getMessage.contains("error")) } + + test("analyzed plan should not change after it's generated") { + val df = spark.range(10).filter('id > 0).as("a") + val analyzedPlan = df.queryExecution.analyzed + val tag = new TreeNodeTag[String]("test") + analyzedPlan.setTagValue(tag, "tag") + + def checkPlan(l: LogicalPlan): Unit = { + assert(l.isInstanceOf[SubqueryAlias]) + val sub = l.asInstanceOf[SubqueryAlias] + assert(sub.child.isInstanceOf[Filter]) + assert(sub.getTagValue(tag).isDefined) + assert(sub.child.getTagValue(tag).isEmpty) + } + + checkPlan(analyzedPlan) + val df2 = df.filter('id > 0) + // trigger optimizaion + df2.queryExecution.optimizedPlan + + // The previous analyzed plan should not get changed. + checkPlan(analyzedPlan) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index b3a5c687f775..b4c1472ecb81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -187,7 +187,7 @@ class PartitionBatchPruningSuite val result = df.collect().map(_(0)).toArray assert(result.length === 1) - val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect { + val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect { case in: InMemoryTableScanExec => (in.readPartitions.value, in.readBatches.value) }.head assert(readPartitions === 5) @@ -208,7 +208,7 @@ class PartitionBatchPruningSuite df.collect().map(_(0)).toArray } - val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect { + val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect { case in: InMemoryTableScanExec => (in.readPartitions.value, in.readBatches.value) }.head From 5e0ab9a04ebc3b32b43b16fefbd8db94dc97ac4a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 15 Jul 2019 11:01:51 +0800 Subject: [PATCH 2/5] address comments --- .../apache/spark/sql/catalyst/trees/TreeNode.scala | 14 +++++++------- .../spark/sql/execution/QueryExecution.scala | 1 + 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 0596dc00985a..8cf02958f284 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees import java.util.UUID -import scala.collection.{mutable, Map} +import scala.collection.Map import scala.reflect.ClassTag import org.apache.commons.lang3.ClassUtils 
@@ -88,18 +88,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * A mutable map for holding auxiliary information of this tree node. It will be carried over * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`. */ - private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty + private val tags = new java.util.concurrent.ConcurrentHashMap[TreeNodeTag[_], Any]() protected def copyTagsFrom(other: BaseType): Unit = { - tags ++= other.tags + tags.putAll(other.tags) } def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = { - tags(tag) = value + tags.put(tag, value) } def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = { - tags.get(tag).map(_.asInstanceOf[T]) + Option(tags.get(tag)).map(_.asInstanceOf[T]) } /** @@ -287,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { mapChildren(_.transformDown(rule)) } else { // If the transform function replaces this node with a new one, carry over the tags. - afterRule.tags ++= this.tags + afterRule.copyTagsFrom(this) afterRule.mapChildren(_.transformDown(rule)) } } @@ -311,7 +311,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } // If the transform function replaces this node with a new one, carry over the tags. - newNode.tags ++= this.tags + newNode.copyTagsFrom(this) newNode } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9fa8de84ef67..42b6f94dead8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -60,6 +60,7 @@ class QueryExecution( lazy val analyzed: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.ANALYSIS) { SparkSession.setActiveSession(sparkSession) + // We can't clone `logical` here, which will reset the `_analyzed` flag. sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } From 6f9b59f5d360a041f6b825ddeb010dc03abee64c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 17 Jul 2019 01:34:21 +0800 Subject: [PATCH 3/5] simplify --- .../spark/sql/catalyst/trees/TreeNode.scala | 10 +++--- .../execution/columnar/InMemoryRelation.scala | 35 ++++++++----------- 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 8cf02958f284..010e31fae293 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees import java.util.UUID -import scala.collection.Map +import scala.collection.{mutable, Map} import scala.reflect.ClassTag import org.apache.commons.lang3.ClassUtils @@ -88,18 +88,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * A mutable map for holding auxiliary information of this tree node. It will be carried over * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`. 
*/ - private val tags = new java.util.concurrent.ConcurrentHashMap[TreeNodeTag[_], Any]() + private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty protected def copyTagsFrom(other: BaseType): Unit = { - tags.putAll(other.tags) + tags ++= other.tags } def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = { - tags.put(tag, value) + tags(tag) = value } def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = { - Option(tags.get(tag)).map(_.asInstanceOf[T]) + tags.get(tag).map(_.asInstanceOf[T]) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index b326cfaf5a44..397614e9fdb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel @@ -150,14 +149,14 @@ object InMemoryRelation { logicalPlan: LogicalPlan): InMemoryRelation = { val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName) val relation = new InMemoryRelation(child.output, cacheBuilder, logicalPlan.outputOrdering) - relation.setStatsOfPlanToCache(logicalPlan.stats) + relation.statsOfPlanToCache = logicalPlan.stats relation } def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = { val relation = new InMemoryRelation( cacheBuilder.cachedPlan.output, cacheBuilder, logicalPlan.outputOrdering) - relation.setStatsOfPlanToCache(logicalPlan.stats) + relation.statsOfPlanToCache = logicalPlan.stats relation } @@ -167,11 +166,9 @@ object InMemoryRelation { outputOrdering: Seq[SortOrder], statsOfPlanToCache: Statistics): InMemoryRelation = { val relation = InMemoryRelation(output, cacheBuilder, outputOrdering) - relation.setStatsOfPlanToCache(statsOfPlanToCache) + relation.statsOfPlanToCache = statsOfPlanToCache relation } - - val STATS_OF_PLAN_TO_CACHE_TAG = new TreeNodeTag[Statistics]("stats_of_plan_to_cache") } case class InMemoryRelation( @@ -179,15 +176,8 @@ case class InMemoryRelation( @transient cacheBuilder: CachedRDDBuilder, override val outputOrdering: Seq[SortOrder]) extends logical.LeafNode with MultiInstanceRelation { - import InMemoryRelation.STATS_OF_PLAN_TO_CACHE_TAG - - def setStatsOfPlanToCache(statsOfPlanToCache: Statistics): Unit = { - setTagValue(STATS_OF_PLAN_TO_CACHE_TAG, statsOfPlanToCache) - } - def getStatsOfPlanToCache(): Statistics = { - getTagValue(STATS_OF_PLAN_TO_CACHE_TAG).get - } + @volatile var statsOfPlanToCache: Statistics = null override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) @@ -203,20 +193,19 @@ case class InMemoryRelation( private[sql] def updateStats( rowCount: Long, newColStats: Map[Attribute, ColumnStat]): Unit = this.synchronized { - val statsOfPlanToCache = getStatsOfPlanToCache() val newStats = statsOfPlanToCache.copy( rowCount = Some(rowCount), attributeStats = AttributeMap((statsOfPlanToCache.attributeStats ++ newColStats).toSeq) ) - 
setStatsOfPlanToCache(newStats) + statsOfPlanToCache = newStats } override def computeStats(): Statistics = { if (!cacheBuilder.isCachedColumnBuffersLoaded) { // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. - getStatsOfPlanToCache() + statsOfPlanToCache } else { - getStatsOfPlanToCache().copy( + statsOfPlanToCache.copy( sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue, rowCount = Some(cacheBuilder.rowCountStats.value.longValue) ) @@ -224,14 +213,20 @@ case class InMemoryRelation( } def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = - InMemoryRelation(newOutput, cacheBuilder, outputOrdering, getStatsOfPlanToCache()) + InMemoryRelation(newOutput, cacheBuilder, outputOrdering, statsOfPlanToCache) override def newInstance(): this.type = { InMemoryRelation( output.map(_.newInstance()), cacheBuilder, outputOrdering, - getStatsOfPlanToCache()).asInstanceOf[this.type] + statsOfPlanToCache).asInstanceOf[this.type] + } + + override def clone(): LogicalPlan = { + val cloned = this.copy() + cloned.statsOfPlanToCache = this.statsOfPlanToCache + cloned } override def simpleString(maxFields: Int): String = From 66f128139cab0067e56752472c7b1d0d09062843 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 19 Jul 2019 22:23:58 +0800 Subject: [PATCH 4/5] address comments --- .../spark/sql/catalyst/trees/TreeNode.scala | 9 +++- .../execution/columnar/InMemoryRelation.scala | 1 + .../sql/execution/command/SetCommand.scala | 2 - .../spark/sql/execution/command/cache.scala | 2 - .../SaveIntoDataSourceCommand.scala | 2 + .../sql/execution/QueryExecutionSuite.scala | 49 ++++++++++++------- 6 files changed, 41 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 010e31fae293..e79000d58350 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -429,8 +429,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { private def makeCopy( newArgs: Array[AnyRef], allowEmptyArgs: Boolean): BaseType = attachTree(this, "makeCopy") { + val allCtors = getClass.getConstructors + if (newArgs.isEmpty && allCtors.isEmpty) { + // This is a singleton object which doesn't have any constructor. Just return `this` as we + // can't copy it. + return this + } + // Skip no-arg constructors that are just there for kryo. - val ctors = getClass.getConstructors.filter(allowEmptyArgs || _.getParameterTypes.size != 0) + val ctors = allCtors.filter(allowEmptyArgs || _.getParameterTypes.size != 0) if (ctors.isEmpty) { sys.error(s"No valid constructor for $nodeName") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 397614e9fdb6..b77f90d19b62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -223,6 +223,7 @@ case class InMemoryRelation( statsOfPlanToCache).asInstanceOf[this.type] } + // override `clone` since the default implementation won't carry over mutable states. 
override def clone(): LogicalPlan = { val cloned = this.copy() cloned.statsOfPlanToCache = this.statsOfPlanToCache diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index ab86ff56c6a6..39b08e2894dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -168,6 +168,4 @@ case object ResetCommand extends RunnableCommand with IgnoreCachedData { sparkSession.sessionState.conf.clear() Seq.empty[Row] } - - override def clone(): LogicalPlan = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index cc89911faafb..7b00769308a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -89,6 +89,4 @@ case object ClearCacheCommand extends RunnableCommand with IgnoreCachedData { sparkSession.catalog.clearCache() Seq.empty[Row] } - - override def clone(): LogicalPlan = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 2e3f19044f26..a1de287b93f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -53,6 +53,8 @@ case class SaveIntoDataSourceCommand( s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" } + // Override `clone` since the default implementation will turn `CaseInsensitiveMap` to a normal + // map. 
override def clone(): LogicalPlan = { SaveIntoDataSourceCommand(query.clone(), dataSource, options, mode) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 1ace35efc591..dbedbdc82534 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution import scala.io.Source import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -138,28 +139,38 @@ class QueryExecutionSuite extends SharedSQLContext { (_: LogicalPlan) => throw new Error("error")) val error = intercept[Error](qe.toString) assert(error.getMessage.contains("error")) + + spark.experimental.extraStrategies = Nil } - test("analyzed plan should not change after it's generated") { - val df = spark.range(10).filter('id > 0).as("a") - val analyzedPlan = df.queryExecution.analyzed - val tag = new TreeNodeTag[String]("test") - analyzedPlan.setTagValue(tag, "tag") - - def checkPlan(l: LogicalPlan): Unit = { - assert(l.isInstanceOf[SubqueryAlias]) - val sub = l.asInstanceOf[SubqueryAlias] - assert(sub.child.isInstanceOf[Filter]) - assert(sub.getTagValue(tag).isDefined) - assert(sub.child.getTagValue(tag).isEmpty) + test("SPARK-28346: clone the query plan between analyzer, optimizer and planner") { + val tag1 = new TreeNodeTag[String]("a") + val tag2 = new TreeNodeTag[String]("b") + val tag3 = new TreeNodeTag[String]("c") + val tag4 = new TreeNodeTag[String]("d") + + def assertNoTag(tag: TreeNodeTag[String], plans: QueryPlan[_]*): Unit = { + plans.foreach { plan => + assert(plan.getTagValue(tag).isEmpty) + } } - checkPlan(analyzedPlan) - val df2 = df.filter('id > 0) - // trigger optimizaion - df2.queryExecution.optimizedPlan + val df = spark.range(10) + val analyzedPlan = df.queryExecution.analyzed + val optimizedPlan = df.queryExecution.optimizedPlan + val physicalPlan = df.queryExecution.sparkPlan + val finalPlan = df.queryExecution.executedPlan + + analyzedPlan.setTagValue(tag1, "v") + assertNoTag(tag1, optimizedPlan, physicalPlan, finalPlan) + + optimizedPlan.setTagValue(tag2, "v") + assertNoTag(tag2, analyzedPlan, physicalPlan, finalPlan) + + physicalPlan.setTagValue(tag3, "v") + assertNoTag(tag3, analyzedPlan, optimizedPlan, finalPlan) - // The previous analyzed plan should not get changed. 
- checkPlan(analyzedPlan) + finalPlan.setTagValue(tag4, "v") + assertNoTag(tag4, analyzedPlan, optimizedPlan, physicalPlan) } } From 4f75ba4db1455dcc35da88b86a46188ce568f37c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 23 Jul 2019 15:31:39 +0800 Subject: [PATCH 5/5] improve test --- .../spark/sql/execution/QueryExecution.scala | 7 ++++ .../sql/execution/QueryExecutionSuite.scala | 42 +++++++++++++------ 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 42b6f94dead8..1583b8d3a1f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -67,10 +67,14 @@ class QueryExecution( lazy val withCachedData: LogicalPlan = { assertAnalyzed() assertSupported() + // clone the plan to avoid sharing the plan instance between different stages like analyzing, + // optimizing and planning. sparkSession.sharedState.cacheManager.useCachedData(analyzed.clone()) } lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) { + // clone the plan to avoid sharing the plan instance between different stages like analyzing, + // optimizing and planning. sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker) } @@ -78,12 +82,15 @@ class QueryExecution( SparkSession.setActiveSession(sparkSession) // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. + // Clone the logical plan here, in case the planner rules change the states of the logical plan. planner.plan(ReturnAnswer(optimizedPlan.clone())).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. lazy val executedPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { + // clone the plan to avoid sharing the plan instance between different stages like analyzing, + // optimizing and planning. 
prepareForExecution(sparkPlan.clone()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index dbedbdc82534..39c87c9eeb47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import scala.io.Source -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, FastOperator} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -143,11 +143,10 @@ class QueryExecutionSuite extends SharedSQLContext { spark.experimental.extraStrategies = Nil } - test("SPARK-28346: clone the query plan between analyzer, optimizer and planner") { + test("SPARK-28346: clone the query plan between different stages") { val tag1 = new TreeNodeTag[String]("a") val tag2 = new TreeNodeTag[String]("b") val tag3 = new TreeNodeTag[String]("c") - val tag4 = new TreeNodeTag[String]("d") def assertNoTag(tag: TreeNodeTag[String], plans: QueryPlan[_]*): Unit = { plans.foreach { plan => @@ -157,20 +156,39 @@ class QueryExecutionSuite extends SharedSQLContext { val df = spark.range(10) val analyzedPlan = df.queryExecution.analyzed + val cachedPlan = df.queryExecution.withCachedData val optimizedPlan = df.queryExecution.optimizedPlan - val physicalPlan = df.queryExecution.sparkPlan - val finalPlan = df.queryExecution.executedPlan analyzedPlan.setTagValue(tag1, "v") - assertNoTag(tag1, optimizedPlan, physicalPlan, finalPlan) + assertNoTag(tag1, cachedPlan, optimizedPlan) + + cachedPlan.setTagValue(tag2, "v") + assertNoTag(tag2, analyzedPlan, optimizedPlan) - optimizedPlan.setTagValue(tag2, "v") - assertNoTag(tag2, analyzedPlan, physicalPlan, finalPlan) + optimizedPlan.setTagValue(tag3, "v") + assertNoTag(tag3, analyzedPlan, cachedPlan) - physicalPlan.setTagValue(tag3, "v") - assertNoTag(tag3, analyzedPlan, optimizedPlan, finalPlan) + val tag4 = new TreeNodeTag[String]("d") + try { + spark.experimental.extraStrategies = Seq(new SparkStrategy() { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + plan.foreach { + case r: org.apache.spark.sql.catalyst.plans.logical.Range => + r.setTagValue(tag4, "v") + case _ => + } + Seq(FastOperator(plan.output)) + } + }) + // trigger planning + df.queryExecution.sparkPlan + assert(optimizedPlan.getTagValue(tag4).isEmpty) + } finally { + spark.experimental.extraStrategies = Nil + } - finalPlan.setTagValue(tag4, "v") - assertNoTag(tag4, analyzedPlan, optimizedPlan, physicalPlan) + val tag5 = new TreeNodeTag[String]("e") + df.queryExecution.executedPlan.setTagValue(tag5, "v") + assertNoTag(tag5, df.queryExecution.sparkPlan) } }
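
A minimal standalone sketch (not part of the patch) of the behavior the new QueryExecutionSuite test asserts, under the assumption of a local SparkSession; the object name, app name and tag names below are illustrative, and the catalyst tree APIs it touches (TreeNodeTag, setTagValue, getTagValue) are internal, so this is intended to run inside Spark's own build the way the suite does. The idea: because QueryExecution now clones the plan between analysis, optimization and planning, the plans held by the different phases are distinct node instances, and tagging one of them is not visible on the others.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

object CloneBetweenPhasesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("clone-sketch").getOrCreate()

    val df = spark.range(10)
    // Materialize every phase first, as the test does, so each lazy val has
    // already been computed from a clone of the previous phase's plan.
    val analyzed  = df.queryExecution.analyzed
    val optimized = df.queryExecution.optimizedPlan
    val physical  = df.queryExecution.sparkPlan
    val executed  = df.queryExecution.executedPlan

    // Logical side: the optimizer ran on a clone of the analyzed plan, so the
    // optimized tree holds distinct node instances and does not see this tag.
    // Without the cloning introduced by this patch, a simple Range plan could
    // share the very same node object across phases and the tag would leak.
    val tag1 = new TreeNodeTag[String]("sketch1")
    analyzed.setTagValue(tag1, "v")
    assert(optimized.getTagValue(tag1).isEmpty)

    // Physical side: executedPlan is prepared from sparkPlan.clone(), so
    // tagging it leaves the original sparkPlan untouched.
    val tag2 = new TreeNodeTag[String]("sketch2")
    executed.setTagValue(tag2, "v")
    assert(physical.getTagValue(tag2).isEmpty)

    spark.stop()
  }
}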