@@ -287,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
mapChildren(_.transformDown(rule))
} else {
// If the transform function replaces this node with a new one, carry over the tags.
- afterRule.tags ++= this.tags
+ afterRule.copyTagsFrom(this)
afterRule.mapChildren(_.transformDown(rule))
}
}
@@ -311,7 +311,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
}
// If the transform function replaces this node with a new one, carry over the tags.
- newNode.tags ++= this.tags
+ newNode.copyTagsFrom(this)
newNode
}
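Note: both hunks above replace the inline `tags ++= this.tags` with a named helper. A minimal sketch of what `copyTagsFrom` presumably does on `TreeNode`, assuming the tags live in a mutable map keyed by `TreeNodeTag` (the field layout here is an assumption, not copied from the PR):

```scala
// Sketch only: assumes TreeNode stores its tags in a mutable Map[TreeNodeTag[_], Any].
def copyTagsFrom(other: BaseType): Unit = {
  // Same effect as the old inline `tags ++= this.tags`, but named and reusable.
  tags ++= other.tags
}
```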

@@ -429,8 +429,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
private def makeCopy(
newArgs: Array[AnyRef],
allowEmptyArgs: Boolean): BaseType = attachTree(this, "makeCopy") {
val allCtors = getClass.getConstructors
if (newArgs.isEmpty && allCtors.isEmpty) {
// This is a singleton object which doesn't have any constructor. Just return `this` as we
// can't copy it.
return this
}

// Skip no-arg constructors that are just there for kryo.
- val ctors = getClass.getConstructors.filter(allowEmptyArgs || _.getParameterTypes.size != 0)
+ val ctors = allCtors.filter(allowEmptyArgs || _.getParameterTypes.size != 0)
if (ctors.isEmpty) {
sys.error(s"No valid constructor for $nodeName")
}
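The new guard matters because a plan node defined as a Scala `object` exposes no public constructors through reflection, so there is nothing `makeCopy` could call. A small REPL-style illustration (the names are made up for the example):

```scala
// A Scala `object` compiles to a class with a private constructor, so reflection
// reports zero public constructors; a case class exposes its primary constructor.
object SingletonNode
case class PairNode(left: Int, right: Int)

assert(SingletonNode.getClass.getConstructors.isEmpty)      // nothing for makeCopy to invoke
assert(PairNode(1, 2).getClass.getConstructors.nonEmpty)
```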
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.Rule
- import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, StringConcat}
+ import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.adaptive.InsertAdaptiveSparkPlan
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
@@ -60,36 +60,38 @@ class QueryExecution(

lazy val analyzed: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.ANALYSIS) {
SparkSession.setActiveSession(sparkSession)
// We can't clone `logical` here, which will reset the `_analyzed` flag.
sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker)
}

lazy val withCachedData: LogicalPlan = {
Member: Maybe not necessary, but should we clone `logical` too before sending to the analyzer?

Contributor Author: Yea, I think we should.

assertAnalyzed()
assertSupported()
- sparkSession.sharedState.cacheManager.useCachedData(analyzed)
+ // clone the plan to avoid sharing the plan instance between different stages like analyzing,
+ // optimizing and planning.
+ sparkSession.sharedState.cacheManager.useCachedData(analyzed.clone())
}

lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) {
- sparkSession.sessionState.optimizer.executeAndTrack(withCachedData, tracker)
+ // clone the plan to avoid sharing the plan instance between different stages like analyzing,
+ // optimizing and planning.
+ sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker)
Member: Since now the query plan is mutable, I think it's better to limit the life cycle of a query plan instance. We can clone the query plan between analyzer, optimizer and planner, so that the life cycle is limited to one stage. If we decide to clone the plan after each stage, will any test fail if we do not clone it?

Contributor Author: Test added.

}

lazy val sparkPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) {
SparkSession.setActiveSession(sparkSession)
- // Runtime re-optimization requires a unique instance of every node in the logical plan.
- val logicalPlan = if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
- optimizedPlan.clone()
- } else {
- optimizedPlan
- }
// TODO: We use next(), i.e. take the first plan returned by the planner, here for now,
// but we will implement to choose the best plan.
- planner.plan(ReturnAnswer(logicalPlan)).next()
+ // Clone the logical plan here, in case the planner rules change the states of the logical plan.
+ planner.plan(ReturnAnswer(optimizedPlan.clone())).next()
}

// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
lazy val executedPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) {
- prepareForExecution(sparkPlan)
+ // clone the plan to avoid sharing the plan instance between different stages like analyzing,
+ // optimizing and planning.
+ prepareForExecution(sparkPlan.clone())
}
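To make the life-cycle concern from the review concrete: tags are mutable state on the tree, so any stage holding the same plan instance sees them. A rough sketch of the isolation the extra `clone()` calls buy, using `OneRowRelation` purely as a stand-in plan:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

val tag = new TreeNodeTag[String]("touched-by-a-later-stage")

val plan: LogicalPlan = OneRowRelation()   // stand-in for e.g. optimizedPlan
val planForNextStage = plan.clone()        // what QueryExecution now hands to the next stage

planForNextStage.setTagValue(tag, "v")     // the later stage mutates its own copy...
assert(plan.getTagValue(tag).isEmpty)      // ...and the earlier stage's instance stays clean
```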

@@ -223,6 +223,13 @@ case class InMemoryRelation(
statsOfPlanToCache).asInstanceOf[this.type]
}

// override `clone` since the default implementation won't carry over mutable states.
override def clone(): LogicalPlan = {
val cloned = this.copy()
cloned.statsOfPlanToCache = this.statsOfPlanToCache
cloned
}

override def simpleString(maxFields: Int): String =
s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}"
}
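The same pattern applies to any node carrying mutable state outside its constructor parameters: the default copy-based clone rebuilds the node from constructor arguments only, so that state must be transferred explicitly. A hypothetical node, for illustration only (not part of Spark):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}

// Hypothetical example: `hitCount` is not a constructor parameter, so a plain
// copy()-based clone would silently reset it to zero.
case class CountingRelation(output: Seq[Attribute]) extends LeafNode {
  var hitCount: Long = 0L

  override def clone(): LogicalPlan = {
    val cloned = this.copy()
    cloned.hitCount = this.hitCount   // carry the mutable state over by hand
    cloned
  }
}
```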
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
- import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData
+ import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.types.{StringType, StructField, StructType}
@@ -52,4 +52,10 @@ case class SaveIntoDataSourceCommand(
val redacted = SQLConf.get.redactOptions(options)
s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}"
}

// Override `clone` since the default implementation will turn `CaseInsensitiveMap` to a normal
// map.
override def clone(): LogicalPlan = {
SaveIntoDataSourceCommand(query.clone(), dataSource, options, mode)
Contributor Author: The `mapChildren` in `TreeNode` will change the map type (from `CaseInsensitiveMap` to a normal map).

}
}
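To illustrate the author's point about why losing the wrapper type matters (a rough sketch; the path and keys are made up):

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

// CaseInsensitiveMap resolves keys regardless of case; a plain Map does not.
val options = CaseInsensitiveMap(Map("Path" -> "/tmp/data"))
assert(options.get("path").isDefined)      // case-insensitive lookup succeeds

// If a copy rebuilds the options as a normal Map, the same lookup quietly fails.
val downgraded = Map("Path" -> "/tmp/data")
assert(downgraded.get("path").isEmpty)
```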
@@ -18,8 +18,10 @@ package org.apache.spark.sql.execution

import scala.io.Source

- import org.apache.spark.sql.AnalysisException
+ import org.apache.spark.sql.{AnalysisException, FastOperator}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

@@ -137,5 +139,56 @@ class QueryExecutionSuite extends SharedSQLContext {
(_: LogicalPlan) => throw new Error("error"))
val error = intercept[Error](qe.toString)
assert(error.getMessage.contains("error"))

spark.experimental.extraStrategies = Nil
}

test("SPARK-28346: clone the query plan between different stages") {
val tag1 = new TreeNodeTag[String]("a")
val tag2 = new TreeNodeTag[String]("b")
val tag3 = new TreeNodeTag[String]("c")

def assertNoTag(tag: TreeNodeTag[String], plans: QueryPlan[_]*): Unit = {
plans.foreach { plan =>
assert(plan.getTagValue(tag).isEmpty)
}
}

val df = spark.range(10)
val analyzedPlan = df.queryExecution.analyzed
val cachedPlan = df.queryExecution.withCachedData
val optimizedPlan = df.queryExecution.optimizedPlan

analyzedPlan.setTagValue(tag1, "v")
assertNoTag(tag1, cachedPlan, optimizedPlan)

cachedPlan.setTagValue(tag2, "v")
assertNoTag(tag2, analyzedPlan, optimizedPlan)

optimizedPlan.setTagValue(tag3, "v")
assertNoTag(tag3, analyzedPlan, cachedPlan)

val tag4 = new TreeNodeTag[String]("d")
try {
spark.experimental.extraStrategies = Seq(new SparkStrategy() {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
plan.foreach {
case r: org.apache.spark.sql.catalyst.plans.logical.Range =>
r.setTagValue(tag4, "v")
case _ =>
}
Seq(FastOperator(plan.output))
}
})
// trigger planning
df.queryExecution.sparkPlan
assert(optimizedPlan.getTagValue(tag4).isEmpty)
} finally {
spark.experimental.extraStrategies = Nil
}

val tag5 = new TreeNodeTag[String]("e")
df.queryExecution.executedPlan.setTagValue(tag5, "v")
assertNoTag(tag5, df.queryExecution.sparkPlan)
}
}
@@ -187,7 +187,7 @@ class PartitionBatchPruningSuite
val result = df.collect().map(_(0)).toArray
assert(result.length === 1)

- val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect {
+ val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect {
case in: InMemoryTableScanExec => (in.readPartitions.value, in.readBatches.value)
}.head
assert(readPartitions === 5)
@@ -208,7 +208,7 @@ class PartitionBatchPruningSuite
df.collect().map(_(0)).toArray
}

- val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect {
+ val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect {
case in: InMemoryTableScanExec => (in.readPartitions.value, in.readBatches.value)
}.head
