Skip to content

Commit e0cd9ce

Browse files
committed
make TreeNode tag type safe
1 parent 5fae8f7 commit e0cd9ce

File tree

6 files changed

+36
-23
lines changed

6 files changed

+36
-23
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ case class OneRowRelation() extends LeafNode {
10831083
/** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
10841084
override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
10851085
val newCopy = OneRowRelation()
1086-
newCopy.tags ++= this.tags
1086+
newCopy.copyTagsFrom(this)
10871087
newCopy
10881088
}
10891089
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,8 @@ object CurrentOrigin {
7474
}
7575
}
7676

77-
// The name of the tree node tag. This is preferred over using string directly, as we can easily
78-
// find all the defined tags.
79-
case class TreeNodeTagName(name: String)
77+
// A tag of a `TreeNode`, which defines name and type
78+
case class TreeNodeTag[T](name: String)
8079

8180
// scalastyle:off
8281
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
@@ -89,7 +88,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
8988
* A mutable map for holding auxiliary information of this tree node. It will be carried over
9089
* when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
9190
*/
92-
val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
91+
private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty
92+
93+
protected def copyTagsFrom(other: BaseType): Unit = {
94+
tags ++= other.tags
95+
}
96+
97+
def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
98+
tags(tag) = value
99+
}
100+
101+
def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = {
102+
tags.get(tag).map(_.asInstanceOf[T])
103+
}
93104

94105
/**
95106
* Returns a Seq of the children of this node.
@@ -418,7 +429,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
418429
try {
419430
CurrentOrigin.withOrigin(origin) {
420431
val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
421-
res.tags ++= this.tags
432+
res.copyTagsFrom(this)
422433
res
423434
}
424435
} catch {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -622,31 +622,33 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
622622
}
623623

624624
test("tags will be carried over after copy & transform") {
625+
val tag = TreeNodeTag[String]("test")
626+
625627
withClue("makeCopy") {
626628
val node = Dummy(None)
627-
node.tags += TreeNodeTagName("test") -> "a"
629+
node.setTagValue(tag, "a")
628630
val copied = node.makeCopy(Array(Some(Literal(1))))
629-
assert(copied.tags(TreeNodeTagName("test")) == "a")
631+
assert(copied.getTagValue(tag) == Some("a"))
630632
}
631633

632634
def checkTransform(
633635
sameTypeTransform: Expression => Expression,
634636
differentTypeTransform: Expression => Expression): Unit = {
635637
val child = Dummy(None)
636-
child.tags += TreeNodeTagName("test") -> "child"
638+
child.setTagValue(tag, "child")
637639
val node = Dummy(Some(child))
638-
node.tags += TreeNodeTagName("test") -> "parent"
640+
node.setTagValue(tag, "parent")
639641

640642
val transformed = sameTypeTransform(node)
641643
// Both the child and parent keep the tags
642-
assert(transformed.tags(TreeNodeTagName("test")) == "parent")
643-
assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
644+
assert(transformed.getTagValue(tag) == Some("parent"))
645+
assert(transformed.children.head.getTagValue(tag) == Some("child"))
644646

645647
val transformed2 = differentTypeTransform(node)
646648
// Both the child and parent keep the tags, even if we transform the node to a new one of
647649
// different type.
648-
assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
649-
assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
650+
assert(transformed2.getTagValue(tag) == Some("parent"))
651+
assert(transformed2.children.head.getTagValue(tag) == Some("child"))
650652
}
651653

652654
withClue("transformDown") {

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,16 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
3333
import org.apache.spark.sql.catalyst.expressions._
3434
import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
3535
import org.apache.spark.sql.catalyst.plans.QueryPlan
36+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
3637
import org.apache.spark.sql.catalyst.plans.physical._
37-
import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
38+
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
3839
import org.apache.spark.sql.execution.metric.SQLMetric
3940
import org.apache.spark.sql.types.DataType
4041

4142
object SparkPlan {
4243
// a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
4344
// when converting a logical plan to a physical plan.
44-
val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
45+
val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan")
4546
}
4647

4748
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
6969
case ReturnAnswer(rootPlan) => rootPlan
7070
case _ => plan
7171
}
72-
p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
72+
p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan)
7373
p
7474
}
7575
}

sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
2121

2222
import org.apache.spark.sql.TPCDSQuerySuite
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
24-
import org.apache.spark.sql.catalyst.plans.QueryPlan
2524
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
2625
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
2726
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -81,12 +80,12 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
8180
// The exchange related nodes are created after the planning, they don't have corresponding
8281
// logical plan.
8382
case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
84-
assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
83+
assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
8584

8685
// The subquery exec nodes are just wrappers of the actual nodes, they don't have
8786
// corresponding logical plan.
8887
case _: SubqueryExec | _: ReusedSubqueryExec =>
89-
assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
88+
assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
9089

9190
case _ if isScanPlanTree(plan) =>
9291
// The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
@@ -120,9 +119,9 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
120119
}
121120

122121
private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
123-
assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
124-
node.getClass.getSimpleName + " does not have a logical plan link")
125-
node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
122+
node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse {
123+
fail(node.getClass.getSimpleName + " does not have a logical plan link")
124+
}
126125
}
127126

128127
private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {

0 commit comments

Comments
 (0)