
Commit 8a6a490

cloud-fan and pengbo committed
[SPARK-27747][SQL] add a logical plan link in the physical plan
It's pretty useful if we can convert a physical plan back to a logical plan, e.g., in apache#24389.

This PR introduces a new feature to `TreeNode`, which allows a `TreeNode` to carry some extra information via a mutable map and to keep that information when it's copied. The planner leverages this feature to put the logical plan into the physical plan.

Tested with a test suite that runs all TPCDS queries and checks that some common physical plans contain the corresponding logical plans.

Closes apache#24626 from cloud-fan/link.

Lead-authored-by: Wenchen Fan <wenchen@databricks.com>
Co-authored-by: Peng Bo <bo.peng1019@gmail.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
1 parent c7c2bda commit 8a6a490
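To illustrate the mechanism the commit message describes, here is a minimal self-contained sketch of the tag-carrying contract (toy types, not the actual Spark classes):

    import scala.collection.mutable

    // Toy stand-ins for TreeNodeTagName and TreeNode, showing the carry-over contract:
    // whenever an operation replaces a node, the replacement inherits the original's tags.
    case class TagName(name: String)

    abstract class Node {
      // Auxiliary information attached to this node.
      val tags: mutable.Map[TagName, Any] = mutable.Map.empty

      // How this node rebuilds itself (stand-in for makeCopy/transform).
      protected def rebuild(): Node

      // Any operation that replaces this node copies the tags over, which is what the
      // patched makeCopy/transformDown/transformUp below do.
      def replaced(): Node = {
        val newNode = rebuild()
        newNode.tags ++= this.tags
        newNode
      }
    }

    case class Leaf(value: Int) extends Node {
      protected def rebuild(): Node = Leaf(value)
    }

    object TagSketch extends App {
      val n = Leaf(1)
      n.tags += TagName("note") -> "kept"
      assert(n.replaced().tags(TagName("note")) == "kept")
    }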

File tree

6 files changed: +227 -6 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 5 additions & 1 deletion
@@ -996,7 +996,11 @@ case class OneRowRelation() extends LeafNode {
   override def computeStats(): Statistics = Statistics(sizeInBytes = 1)
 
   /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
-  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = OneRowRelation()
+  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
+    val newCopy = OneRowRelation()
+    newCopy.tags ++= this.tags
+    newCopy
+  }
 }
 
 /** A logical plan for `dropDuplicates`. */
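Since the reflective path in `makeCopy` can't handle a zero-arg constructor, this override rebuilds the node and re-attaches the tags by hand. A hedged sketch of the expected behavior after this patch (assumes spark-catalyst on the classpath, e.g. in `spark-shell`):

    import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
    import org.apache.spark.sql.catalyst.trees.TreeNodeTagName

    val rel = OneRowRelation()
    rel.tags += TreeNodeTagName("origin") -> "sketch"
    // newArgs is ignored for this 0-arg node; the override rebuilds it and copies the tags.
    val copied = rel.makeCopy(Array.empty)
    assert(copied.tags(TreeNodeTagName("origin")) == "sketch")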

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 20 additions & 3 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.util.UUID
 
-import scala.collection.Map
+import scala.collection.{mutable, Map}
 import scala.reflect.ClassTag
 
 import org.apache.commons.lang3.ClassUtils
@@ -71,13 +71,23 @@ object CurrentOrigin {
   }
 }
 
+// The name of the tree node tag. This is preferred over using a string directly, as we can easily
+// find all the defined tags.
+case class TreeNodeTagName(name: String)
+
 // scalastyle:off
 abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 // scalastyle:on
   self: BaseType =>
 
   val origin: Origin = CurrentOrigin.get
 
+  /**
+   * A mutable map for holding auxiliary information of this tree node. It will be carried over
+   * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
+   */
+  val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
+
   /**
    * Returns a Seq of the children of this node.
    * Children should not change. Immutability required for containsChild optimization
@@ -262,6 +272,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     if (this fastEquals afterRule) {
       mapChildren(_.transformDown(rule))
     } else {
+      // If the transform function replaces this node with a new one, carry over the tags.
+      afterRule.tags ++= this.tags
       afterRule.mapChildren(_.transformDown(rule))
     }
   }
@@ -275,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    */
   def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
     val afterRuleOnChildren = mapChildren(_.transformUp(rule))
-    if (this fastEquals afterRuleOnChildren) {
+    val newNode = if (this fastEquals afterRuleOnChildren) {
       CurrentOrigin.withOrigin(origin) {
         rule.applyOrElse(this, identity[BaseType])
       }
@@ -284,6 +296,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
       }
     }
+    // If the transform function replaces this node with a new one, carry over the tags.
+    newNode.tags ++= this.tags
+    newNode
   }
 
   /**
@@ -402,7 +417,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
     try {
       CurrentOrigin.withOrigin(origin) {
-        defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
+        val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
+        res.tags ++= this.tags
+        res
       }
     } catch {
       case e: java.lang.IllegalArgumentException =>
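Taken together, these hooks mean a rule that replaces a node no longer loses the node's metadata. An illustrative sketch with real expression nodes (assumes spark-catalyst on the classpath; the test suite below exercises the same semantics with a `Dummy` node):

    import org.apache.spark.sql.catalyst.expressions.{Add, Literal}
    import org.apache.spark.sql.catalyst.trees.TreeNodeTagName

    val expr = Add(Literal(1), Literal(2))
    expr.tags += TreeNodeTagName("note") -> "kept"

    // The rule replaces the root with a new Add node, but the new node inherits the tags.
    val swapped = expr.transformUp { case Add(l, r) => Add(r, l) }
    assert(swapped.tags(TreeNodeTagName("note")) == "kept")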

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala

Lines changed: 51 additions & 0 deletions
@@ -617,4 +617,55 @@ class TreeNodeSuite extends SparkFunSuite {
     val expected = Coalesce(Stream(Literal(1), Literal(3)))
     assert(result === expected)
   }
+
+  test("tags will be carried over after copy & transform") {
+    withClue("makeCopy") {
+      val node = Dummy(None)
+      node.tags += TreeNodeTagName("test") -> "a"
+      val copied = node.makeCopy(Array(Some(Literal(1))))
+      assert(copied.tags(TreeNodeTagName("test")) == "a")
+    }
+
+    def checkTransform(
+        sameTypeTransform: Expression => Expression,
+        differentTypeTransform: Expression => Expression): Unit = {
+      val child = Dummy(None)
+      child.tags += TreeNodeTagName("test") -> "child"
+      val node = Dummy(Some(child))
+      node.tags += TreeNodeTagName("test") -> "parent"
+
+      val transformed = sameTypeTransform(node)
+      // Both the child and parent keep the tags
+      assert(transformed.tags(TreeNodeTagName("test")) == "parent")
+      assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
+
+      val transformed2 = differentTypeTransform(node)
+      // Both the child and parent keep the tags, even if we transform the node to a new one of
+      // different type.
+      assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
+      assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
+    }
+
+    withClue("transformDown") {
+      checkTransform(
+        sameTypeTransform = _ transformDown {
+          case Dummy(None) => Dummy(Some(Literal(1)))
+        },
+        differentTypeTransform = _ transformDown {
+          case Dummy(None) => Literal(1)
+        })
+    }
+
+    withClue("transformUp") {
+      checkTransform(
+        sameTypeTransform = _ transformUp {
+          case Dummy(None) => Dummy(Some(Literal(1)))
+        },
+        differentTypeTransform = _ transformUp {
+          case Dummy(None) => Literal(1)
+        })
+    }
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 7 additions & 2 deletions
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
 
 import org.codehaus.commons.compiler.CompileException
 import org.codehaus.janino.InternalCompilerException
@@ -35,9 +34,15 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.ThreadUtils
+
+object SparkPlan {
+  // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
+  // when converting a logical plan to a physical plan.
+  val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
+}
 
 /**
  * The base class for physical operators.

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 11 additions & 0 deletions
@@ -62,6 +62,17 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SparkPlanner =>
 
+  override def plan(plan: LogicalPlan): Iterator[SparkPlan] = {
+    super.plan(plan).map { p =>
+      val logicalPlan = plan match {
+        case ReturnAnswer(rootPlan) => rootPlan
+        case _ => plan
+      }
+      p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
+      p
+    }
+  }
+
   /**
    * Plans special cases of limit operators.
    */
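With this override, the root of every planned physical (sub)plan carries a link back to the logical plan it was planned from. A hedged usage sketch (assumes a `SparkSession` named `spark` is in scope, e.g. in `spark-shell`):

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.SparkPlan

    val df = spark.range(10).filter("id > 5")
    val physical = df.queryExecution.sparkPlan
    // The planner attached the optimized logical plan as a tag on the physical plan.
    val linked = physical.tags.get(SparkPlan.LOGICAL_PLAN_TAG_NAME)
      .map(_.asInstanceOf[LogicalPlan])
    linked.foreach(lp => println(lp.treeString))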
sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.TPCDSQuerySuite
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.window.WindowExec
+
+class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
+
+  override protected def checkGeneratedCode(plan: SparkPlan): Unit = {
+    super.checkGeneratedCode(plan)
+    checkLogicalPlanTag(plan)
+  }
+
+  private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = {
+    // TODO: an aggregate node without aggregate expressions can also be a final aggregate, but
+    // currently the aggregate node doesn't have a final/partial flag.
+    aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final)
+  }
+
+  // A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes.
+  private def isScanPlanTree(plan: SparkPlan): Boolean = plan match {
+    case p: ProjectExec => isScanPlanTree(p.child)
+    case f: FilterExec => isScanPlanTree(f.child)
+    case _: LeafExecNode => true
+    case _ => false
+  }
+
+  private def checkLogicalPlanTag(plan: SparkPlan): Unit = {
+    plan match {
+      case _: HashJoin | _: BroadcastNestedLoopJoinExec | _: CartesianProductExec
+          | _: ShuffledHashJoinExec | _: SortMergeJoinExec =>
+        assertLogicalPlanType[Join](plan)
+
+      // There is no corresponding logical plan for the physical partial aggregate.
+      case agg: HashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+      case agg: ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+      case agg: SortAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+
+      case _: WindowExec =>
+        assertLogicalPlanType[Window](plan)
+
+      case _: UnionExec =>
+        assertLogicalPlanType[Union](plan)
+
+      case _: SampleExec =>
+        assertLogicalPlanType[Sample](plan)
+
+      case _: GenerateExec =>
+        assertLogicalPlanType[Generate](plan)
+
+      // The exchange-related nodes are created after planning, so they don't have a
+      // corresponding logical plan.
+      case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
+        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+
+      // The subquery exec nodes are just wrappers of the actual nodes, so they don't have a
+      // corresponding logical plan.
+      case _: SubqueryExec | _: ReusedSubqueryExec =>
+        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+
+      case _ if isScanPlanTree(plan) =>
+        // The strategies for planning a scan can remove or add FilterExec/ProjectExec nodes,
+        // so it's not simple to check the whole subtree. A strategy might remove the filter if
+        // it's totally pushed down, e.g.:
+        //   logical = Project(Filter(Scan A))
+        //   physical = ProjectExec(ScanExec A)
+        // Instead, we only check that the leaf nodes match between the logical and physical plan.
+        val logicalLeaves = getLogicalPlan(plan).collectLeaves()
+        val physicalLeaves = plan.collectLeaves()
+        assert(logicalLeaves.length == 1)
+        assert(physicalLeaves.length == 1)
+        physicalLeaves.head match {
+          case _: RangeExec => logicalLeaves.head.isInstanceOf[Range]
+          case _: DataSourceScanExec => logicalLeaves.head.isInstanceOf[LogicalRelation]
+          case _: InMemoryTableScanExec => logicalLeaves.head.isInstanceOf[InMemoryRelation]
+          case _: LocalTableScanExec => logicalLeaves.head.isInstanceOf[LocalRelation]
+          case _: ExternalRDDScanExec[_] => logicalLeaves.head.isInstanceOf[ExternalRDD[_]]
+          case _: BatchScanExec => logicalLeaves.head.isInstanceOf[DataSourceV2Relation]
+          case _ =>
+        }
+        // No need to check the children recursively.
+        return
+
+      case _ =>
+    }
+
+    plan.children.foreach(checkLogicalPlanTag)
+    plan.subqueries.foreach(checkLogicalPlanTag)
+  }
+
+  private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
+    assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
+      node.getClass.getSimpleName + " does not have a logical plan link")
+    node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
+  }
+
+  private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {
+    val logicalPlan = getLogicalPlan(node)
+    val expectedCls = implicitly[ClassTag[T]].runtimeClass
+    assert(expectedCls == logicalPlan.getClass)
+  }
+}
