diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index bc0ca31dc635..1984ba18d343 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import java.util.IdentityHashMap
+
 import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
@@ -443,7 +445,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
   override def verboseString(maxFields: Int): String = simpleString(maxFields)
 
   override def simpleStringWithNodeId(): String = {
-    val operatorId = getTagValue(QueryPlan.OP_ID_TAG).map(id => s"$id").getOrElse("unknown")
+    val operatorId = Option(QueryPlan.localIdMap.get().get(this)).map(id => s"$id")
+      .getOrElse("unknown")
     s"$nodeName ($operatorId)".trim
   }
 
@@ -463,7 +466,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
   }
 
   protected def formattedNodeName: String = {
-    val opId = getTagValue(QueryPlan.OP_ID_TAG).map(id => s"$id").getOrElse("unknown")
+    val opId = Option(QueryPlan.localIdMap.get().get(this)).map(id => s"$id")
+      .getOrElse("unknown")
     val codegenId =
       getTagValue(QueryPlan.CODEGEN_ID_TAG).map(id => s" [codegen id : $id]").getOrElse("")
     s"($opId) $nodeName$codegenId"
@@ -675,9 +679,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
 }
 
 object QueryPlan extends PredicateHelper {
-  val OP_ID_TAG = TreeNodeTag[Int]("operatorId")
   val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId")
 
+  /**
+   * A thread-local map storing the mapping between a query plan and its operator ID.
+   * The scope of this thread local is limited to ExplainUtils.processPlan. It is defined here
+   * because [[QueryPlan]] also needs it, and the `execution` package is not accessible
+   * from `catalyst`.
+   */
+  val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = ThreadLocal.withInitial(() =>
+    new IdentityHashMap[QueryPlan[_], Int]())
+
   /**
    * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
    * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
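A note on the choice of `IdentityHashMap`: Catalyst plan nodes are case classes with structural equality, so two distinct operators that happen to compare equal would collide in a regular `HashMap`, while `IdentityHashMap` keys on reference identity. A minimal, self-contained sketch of the difference (illustrative only, not Spark code; `Node` is a stand-in for a plan node):

```scala
import java.util.{HashMap => JHashMap, IdentityHashMap}

// Stand-in for a plan node: case classes compare structurally,
// just like Catalyst operators.
case class Node(name: String)

object IdentityMapDemo {
  def main(args: Array[String]): Unit = {
    val a = Node("Project")
    val b = Node("Project") // structurally equal to `a`, but a distinct object

    val byEquals = new JHashMap[Node, Int]()
    byEquals.put(a, 1)
    byEquals.put(b, 2) // overwrites the entry for `a`
    println(byEquals.size) // 1

    val byIdentity = new IdentityHashMap[Node, Int]()
    byIdentity.put(a, 1)
    byIdentity.put(b, 2) // kept separate: keys are compared by reference
    println(byIdentity.size) // 2
  }
}
```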
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
index 11f6ae0e47ee..421a963453f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import java.util.Collections.newSetFromMap
 import java.util.IdentityHashMap
-import java.util.Set
 
 import scala.collection.mutable.{ArrayBuffer, BitSet}
 
@@ -30,6 +28,8 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
 import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
 
 object ExplainUtils extends AdaptiveSparkPlanHelper {
+  def localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = QueryPlan.localIdMap
+
   /**
    * Given a input physical plan, performs the following tasks.
    * 1. Computes the whole stage codegen id for current operator and records it in the
@@ -80,24 +80,26 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
   * instances but cached plan is an exception. The `InMemoryRelation#innerChildren` use a shared
   * plan instance across multi-queries. Add lock for this method to avoid tag race condition.
   */
-  def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = synchronized {
+  def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = {
+    val prevIdMap = localIdMap.get()
     try {
-      // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow
-      // intentional overwriting of IDs generated in previous AQE iteration
-      val operators = newSetFromMap[QueryPlan[_]](new IdentityHashMap())
+      // Initialize a reference-unique id map to store generated IDs, which also avoids accidental
+      // overwrites and allows intentional overwriting of IDs generated in a previous AQE iteration
+      val idMap = new IdentityHashMap[QueryPlan[_], Int]()
+      localIdMap.set(idMap)
       // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
       // Exchanges as part of SPARK-42753
       val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]
 
       var currentOperatorID = 0
-      currentOperatorID = generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges,
+      currentOperatorID = generateOperatorIDs(plan, currentOperatorID, idMap, reusedExchanges,
         true)
 
       val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)]
       getSubqueries(plan, subqueries)
 
       currentOperatorID = subqueries.foldLeft(currentOperatorID) {
-        (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges,
+        (curId, plan) => generateOperatorIDs(plan._3.child, curId, idMap, reusedExchanges,
           true)
       }
 
@@ -105,9 +107,9 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
     val optimizedOutExchanges = ArrayBuffer.empty[Exchange]
     reusedExchanges.foreach{ reused =>
       val child = reused.child
-      if (!operators.contains(child)) {
+      if (!idMap.containsKey(child)) {
         optimizedOutExchanges.append(child)
-        currentOperatorID = generateOperatorIDs(child, currentOperatorID, operators,
+        currentOperatorID = generateOperatorIDs(child, currentOperatorID, idMap,
           reusedExchanges, false)
       }
     }
@@ -144,7 +146,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
         append("\n")
       }
     } finally {
-      removeTags(plan)
+      localIdMap.set(prevIdMap)
     }
   }
 
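With per-thread state, the method-level `synchronized` is no longer needed: each thread writes IDs into its own map, and restoring the previous map in `finally` means a nested `processPlan` call (for example via a shared `InMemoryRelation#innerChildren` plan) cannot clobber the IDs of the invocation that triggered it. A minimal sketch of this save/restore pattern, with illustrative names (`ScopedIds` and `withFreshMap` are not Spark APIs):

```scala
import java.util.IdentityHashMap

// Illustrative sketch of thread-local save/restore scoping, not Spark code.
object ScopedIds {
  private val current: ThreadLocal[java.util.Map[AnyRef, Int]] =
    ThreadLocal.withInitial(() => new IdentityHashMap[AnyRef, Int]())

  def withFreshMap[A](body: java.util.Map[AnyRef, Int] => A): A = {
    val previous = current.get() // remember the caller's map, if any
    current.set(new IdentityHashMap[AnyRef, Int]())
    try body(current.get())
    finally current.set(previous) // restore, so outer and nested scopes stay isolated
  }
}
```

Reentrancy falls out of the pattern: a `withFreshMap` call nested inside another one leaves the outer map intact once it returns.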
@@ -159,13 +161,15 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
    * @param plan Input query plan to process
    * @param startOperatorID The start value of operation id. The subsequent operations will be
    *                        assigned higher value.
-   * @param visited A unique set of operators visited by generateOperatorIds. The set is scoped
-   *                at the callsite function processPlan. It serves two purpose: Firstly, it is
-   *                used to avoid accidentally overwriting existing IDs that were generated in
-   *                the same processPlan call. Secondly, it is used to allow for intentional ID
-   *                overwriting as part of SPARK-42753 where an Adaptively Optimized Out Exchange
-   *                and its subtree may contain IDs that were generated in a previous AQE
-   *                iteration's processPlan call which would result in incorrect IDs.
+   * @param idMap A reference-unique map that stores the operators visited by generateOperatorIds
+   *              together with their IDs. The map is scoped to the callsite function processPlan.
+   *              It serves three purposes:
+   *              Firstly, it stores the QueryPlan to generated-ID mapping. Secondly, it is used to
+   *              avoid accidentally overwriting existing IDs that were generated in the same
+   *              processPlan call. Thirdly, it is used to allow for intentional ID overwriting as
+   *              part of SPARK-42753 where an Adaptively Optimized Out Exchange and its subtree
+   *              may contain IDs that were generated in a previous AQE iteration's processPlan
+   *              call, which would result in incorrect IDs.
    * @param reusedExchanges A unique set of ReusedExchange nodes visited which will be used to
    *                        idenitfy adaptively optimized out exchanges in SPARK-42753.
    * @param addReusedExchanges Whether to add ReusedExchange nodes to reusedExchanges set. We set it
@@ -177,7 +181,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
   private def generateOperatorIDs(
       plan: QueryPlan[_],
       startOperatorID: Int,
-      visited: Set[QueryPlan[_]],
+      idMap: java.util.Map[QueryPlan[_], Int],
       reusedExchanges: ArrayBuffer[ReusedExchangeExec],
       addReusedExchanges: Boolean): Int = {
     var currentOperationID = startOperatorID
@@ -186,36 +190,35 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
       return currentOperationID
     }
 
-    def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) {
+    def setOpId(plan: QueryPlan[_]): Unit = idMap.computeIfAbsent(plan, plan => {
       plan match {
         case r: ReusedExchangeExec if addReusedExchanges =>
           reusedExchanges.append(r)
         case _ =>
       }
-      visited.add(plan)
       currentOperationID += 1
-      plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID)
-    }
+      currentOperationID
+    })
 
     plan.foreachUp {
       case _: WholeStageCodegenExec =>
       case _: InputAdapter =>
       case p: AdaptiveSparkPlanExec =>
-        currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, visited,
+        currentOperationID = generateOperatorIDs(p.executedPlan, currentOperationID, idMap,
          reusedExchanges, addReusedExchanges)
        if (!p.executedPlan.fastEquals(p.initialPlan)) {
-          currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, visited,
+          currentOperationID = generateOperatorIDs(p.initialPlan, currentOperationID, idMap,
            reusedExchanges, addReusedExchanges)
        }
        setOpId(p)
      case p: QueryStageExec =>
-        currentOperationID = generateOperatorIDs(p.plan, currentOperationID, visited,
+        currentOperationID = generateOperatorIDs(p.plan, currentOperationID, idMap,
          reusedExchanges, addReusedExchanges)
        setOpId(p)
      case other: QueryPlan[_] =>
        setOpId(other)
        currentOperationID = other.innerChildren.foldLeft(currentOperationID) {
-          (curId, plan) => generateOperatorIDs(plan, curId, visited, reusedExchanges,
+          (curId, plan) => generateOperatorIDs(plan, curId, idMap, reusedExchanges,
            addReusedExchanges)
        }
    }
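`computeIfAbsent` folds the old `visited.contains` check and the ID assignment into one step: the mapping function runs only when the key is absent, so a node that already has an ID is never renumbered. A standalone sketch of the idea (illustrative only; the local `setOpId` here is a stand-in, not the Spark method):

```scala
import java.util.IdentityHashMap

object OpIdDemo {
  def main(args: Array[String]): Unit = {
    val ids = new IdentityHashMap[AnyRef, Int]()
    var nextId = 0

    // Mirrors the shape of setOpId above: the lambda only runs for unseen keys.
    def setOpId(node: AnyRef): Int = ids.computeIfAbsent(node, _ => {
      nextId += 1
      nextId
    })

    val n1 = new AnyRef
    val n2 = new AnyRef
    println(setOpId(n1)) // 1
    println(setOpId(n2)) // 2
    println(setOpId(n1)) // 1 again: the mapping function is not re-run
  }
}
```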
@@ -241,7 +244,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
     }
 
     def collectOperatorWithID(plan: QueryPlan[_]): Unit = {
-      plan.getTagValue(QueryPlan.OP_ID_TAG).foreach { id =>
+      Option(ExplainUtils.localIdMap.get().get(plan)).foreach { id =>
         if (collectedOperators.add(id)) operators += plan
       }
     }
@@ -334,20 +337,6 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
    * `operationId` tag value.
    */
   def getOpId(plan: QueryPlan[_]): String = {
-    plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown")
-  }
-
-  def removeTags(plan: QueryPlan[_]): Unit = {
-    def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = {
-      p.unsetTagValue(QueryPlan.OP_ID_TAG)
-      p.unsetTagValue(QueryPlan.CODEGEN_ID_TAG)
-      children.foreach(removeTags)
-    }
-
-    plan foreach {
-      case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, p.initialPlan))
-      case p: QueryStageExec => remove(p, Seq(p.plan))
-      case plan: QueryPlan[_] => remove(plan, plan.innerChildren)
-    }
+    Option(ExplainUtils.localIdMap.get().get(plan)).map(v => s"$v").getOrElse("unknown")
   }
 }
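One consequence of moving from tags to a scoped map: a lookup outside a `processPlan` invocation (or from another thread) finds no entry, and the callers above render that as "unknown" instead of cleaning tags afterwards, which is why `removeTags` can be deleted. A tiny sketch of the `Option`-wrapped lookup; it uses `java.lang.Integer` values so the absent-key `null` is explicit in a standalone example:

```scala
import java.util.IdentityHashMap

object UnknownIdDemo {
  def main(args: Array[String]): Unit = {
    // An empty map stands in for "no processPlan call is in scope on this thread".
    val ids = new IdentityHashMap[AnyRef, Integer]()
    val node = new AnyRef

    // java.util.Map#get returns null for a missing key; Option(null) is None.
    val rendered = Option(ids.get(node)).map(id => s"$id").getOrElse("unknown")
    println(rendered) // unknown
  }
}
```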