From 8a6a490f83df609be8230682d6aaf593806f75c2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 20 May 2019 13:42:25 -0700 Subject: [PATCH 1/5] [SPARK-27747][SQL] add a logical plan link in the physical plan It's pretty useful if we can convert a physical plan back to a logical plan, e.g., in https://github.com/apache/spark/pull/24389 This PR introduces a new feature to `TreeNode`, which allows `TreeNode` to carry some extra information via a mutable map, and keep the information when it's copied. The planner leverages this feature to put the logical plan into the physical plan. a test suite that runs all TPCDS queries and checks that some common physical plans contain the corresponding logical plans. Closes #24626 from cloud-fan/link. Lead-authored-by: Wenchen Fan Co-authored-by: Peng Bo Signed-off-by: gatorsmile --- .../plans/logical/basicLogicalOperators.scala | 6 +- .../spark/sql/catalyst/trees/TreeNode.scala | 23 ++- .../sql/catalyst/trees/TreeNodeSuite.scala | 51 +++++++ .../spark/sql/execution/SparkPlan.scala | 9 +- .../spark/sql/execution/SparkStrategies.scala | 11 ++ .../LogicalPlanTagInSparkPlanSuite.scala | 133 ++++++++++++++++++ 6 files changed, 227 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 1c255861ad3a..68a0fa499efb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -996,7 +996,11 @@ case class OneRowRelation() extends LeafNode { override def computeStats(): Statistics = Statistics(sizeInBytes = 1) /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */ - override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = OneRowRelation() + override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = { + val newCopy = OneRowRelation() + newCopy.tags ++= this.tags + newCopy + } } /** A logical plan for `dropDuplicates`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index a924f10fb366..957c5c24b08a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees import java.util.UUID -import scala.collection.Map +import scala.collection.{mutable, Map} import scala.reflect.ClassTag import org.apache.commons.lang3.ClassUtils @@ -71,6 +71,10 @@ object CurrentOrigin { } } +// The name of the tree node tag. This is preferred over using string directly, as we can easily +// find all the defined tags. +case class TreeNodeTagName(name: String) + // scalastyle:off abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // scalastyle:on @@ -78,6 +82,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { val origin: Origin = CurrentOrigin.get + /** + * A mutable map for holding auxiliary information of this tree node. 
It will be carried over + * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`. + */ + val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty + /** * Returns a Seq of the children of this node. * Children should not change. Immutability required for containsChild optimization @@ -262,6 +272,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { if (this fastEquals afterRule) { mapChildren(_.transformDown(rule)) } else { + // If the transform function replaces this node with a new one, carry over the tags. + afterRule.tags ++= this.tags afterRule.mapChildren(_.transformDown(rule)) } } @@ -275,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { val afterRuleOnChildren = mapChildren(_.transformUp(rule)) - if (this fastEquals afterRuleOnChildren) { + val newNode = if (this fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) } @@ -284,6 +296,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) } } + // If the transform function replaces this node with a new one, carry over the tags. + newNode.tags ++= this.tags + newNode } /** @@ -402,7 +417,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { try { CurrentOrigin.withOrigin(origin) { - defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType] + val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType] + res.tags ++= this.tags + res } } catch { case e: java.lang.IllegalArgumentException => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index e37cf8a8e217..d5619345f793 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -617,4 +617,55 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Coalesce(Stream(Literal(1), Literal(3))) assert(result === expected) } + + test("tags will be carried over after copy & transform") { + withClue("makeCopy") { + val node = Dummy(None) + node.tags += TreeNodeTagName("test") -> "a" + val copied = node.makeCopy(Array(Some(Literal(1)))) + assert(copied.tags(TreeNodeTagName("test")) == "a") + } + + def checkTransform( + sameTypeTransform: Expression => Expression, + differentTypeTransform: Expression => Expression): Unit = { + val child = Dummy(None) + child.tags += TreeNodeTagName("test") -> "child" + val node = Dummy(Some(child)) + node.tags += TreeNodeTagName("test") -> "parent" + + val transformed = sameTypeTransform(node) + // Both the child and parent keep the tags + assert(transformed.tags(TreeNodeTagName("test")) == "parent") + assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child") + + val transformed2 = differentTypeTransform(node) + // Both the child and parent keep the tags, even if we transform the node to a new one of + // different type. 
+ assert(transformed2.tags(TreeNodeTagName("test")) == "parent") + assert(transformed2.children.head.tags.contains(TreeNodeTagName("test"))) + } + + withClue("transformDown") { + checkTransform( + sameTypeTransform = _ transformDown { + case Dummy(None) => Dummy(Some(Literal(1))) + }, + differentTypeTransform = _ transformDown { + case Dummy(None) => Literal(1) + + }) + } + + withClue("transformUp") { + checkTransform( + sameTypeTransform = _ transformUp { + case Dummy(None) => Dummy(Some(Literal(1))) + }, + differentTypeTransform = _ transformUp { + case Dummy(None) => Literal(1) + + }) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 7646f9613efb..967d5f640143 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext import org.codehaus.commons.compiler.CompileException import org.codehaus.janino.InternalCompilerException @@ -35,9 +34,15 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.trees.TreeNodeTagName import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.DataType -import org.apache.spark.util.ThreadUtils + +object SparkPlan { + // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag + // when converting a logical plan to a physical plan. + val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan") +} /** * The base class for physical operators. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dbc6db62bd82..44b8db6645b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -62,6 +62,17 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode { abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => + override def plan(plan: LogicalPlan): Iterator[SparkPlan] = { + super.plan(plan).map { p => + val logicalPlan = plan match { + case ReturnAnswer(rootPlan) => rootPlan + case _ => plan + } + p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan + p + } + } + /** * Plans special cases of limit operators. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala new file mode 100644 index 000000000000..ca7ced5ef538 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.reflect.ClassTag + +import org.apache.spark.sql.TPCDSQuerySuite +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.window.WindowExec + +class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite { + + override protected def checkGeneratedCode(plan: SparkPlan): Unit = { + super.checkGeneratedCode(plan) + checkLogicalPlanTag(plan) + } + + private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = { + // TODO: aggregate node without aggregate expressions can also be a final aggregate, but + // currently the aggregate node doesn't have a final/partial flag. + aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final) + } + + // A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes. + private def isScanPlanTree(plan: SparkPlan): Boolean = plan match { + case p: ProjectExec => isScanPlanTree(p.child) + case f: FilterExec => isScanPlanTree(f.child) + case _: LeafExecNode => true + case _ => false + } + + private def checkLogicalPlanTag(plan: SparkPlan): Unit = { + plan match { + case _: HashJoin | _: BroadcastNestedLoopJoinExec | _: CartesianProductExec + | _: ShuffledHashJoinExec | _: SortMergeJoinExec => + assertLogicalPlanType[Join](plan) + + // There is no corresponding logical plan for the physical partial aggregate. + case agg: HashAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + case agg: ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + case agg: SortAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + + case _: WindowExec => + assertLogicalPlanType[Window](plan) + + case _: UnionExec => + assertLogicalPlanType[Union](plan) + + case _: SampleExec => + assertLogicalPlanType[Sample](plan) + + case _: GenerateExec => + assertLogicalPlanType[Generate](plan) + + // The exchange related nodes are created after the planning, they don't have corresponding + // logical plan. 
+ case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec => + assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME)) + + // The subquery exec nodes are just wrappers of the actual nodes, they don't have + // corresponding logical plan. + case _: SubqueryExec | _: ReusedSubqueryExec => + assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME)) + + case _ if isScanPlanTree(plan) => + // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes, + // so it's not simple to check. Instead, we only check that the origin LogicalPlan + // contains the corresponding leaf node of the SparkPlan. + // a strategy might remove the filter if it's totally pushed down, e.g.: + // logical = Project(Filter(Scan A)) + // physical = ProjectExec(ScanExec A) + // we only check that leaf modes match between logical and physical plan. + val logicalLeaves = getLogicalPlan(plan).collectLeaves() + val physicalLeaves = plan.collectLeaves() + assert(logicalLeaves.length == 1) + assert(physicalLeaves.length == 1) + physicalLeaves.head match { + case _: RangeExec => logicalLeaves.head.isInstanceOf[Range] + case _: DataSourceScanExec => logicalLeaves.head.isInstanceOf[LogicalRelation] + case _: InMemoryTableScanExec => logicalLeaves.head.isInstanceOf[InMemoryRelation] + case _: LocalTableScanExec => logicalLeaves.head.isInstanceOf[LocalRelation] + case _: ExternalRDDScanExec[_] => logicalLeaves.head.isInstanceOf[ExternalRDD[_]] + case _: BatchScanExec => logicalLeaves.head.isInstanceOf[DataSourceV2Relation] + case _ => + } + // Do not need to check the children recursively. + return + + case _ => + } + + plan.children.foreach(checkLogicalPlanTag) + plan.subqueries.foreach(checkLogicalPlanTag) + } + + private def getLogicalPlan(node: SparkPlan): LogicalPlan = { + assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME), + node.getClass.getSimpleName + " does not have a logical plan link") + node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan] + } + + private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = { + val logicalPlan = getLogicalPlan(node) + val expectedCls = implicitly[ClassTag[T]].runtimeClass + assert(expectedCls == logicalPlan.getClass) + } +} From f45352e2de04ef2aae07af1f8a51fe79d227e61d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 May 2019 11:53:21 -0700 Subject: [PATCH 2/5] [SPARK-27816][SQL] make TreeNode tag type safe ## What changes were proposed in this pull request? Add type parameter to `TreeNodeTag`. ## How was this patch tested? existing tests Closes #24687 from cloud-fan/tag. 
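As an illustration only (this snippet is not part of the patch; it assumes the existing catalyst `Add` and `Literal` expressions and a test/REPL context), the typed tag API is meant to be used roughly like this:

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

// The tag name is still a plain string, but the type parameter ties every
// setTagValue/getTagValue call for this tag to the same value type.
val noteTag = TreeNodeTag[String]("note")

val expr = Add(Literal(1), Literal(2))
expr.setTagValue(noteTag, "added by some rule")

// getTagValue returns Option[String] here, not an untyped Any as before.
val note: Option[String] = expr.getTagValue(noteTag)

// Tags still survive copies and transforms, as introduced in the previous commit.
val rewritten = expr transformDown { case Literal(1, dt) => Literal(0, dt) }
assert(rewritten.getTagValue(noteTag) == Some("added by some rule"))
```

Compared to the previous `tags += TreeNodeTagName("note") -> value` pattern, a mismatched value type is now a compile error rather than a latent `ClassCastException` at the read site.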
Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- .../plans/logical/basicLogicalOperators.scala | 2 +- .../spark/sql/catalyst/trees/TreeNode.scala | 21 ++++++++++++++----- .../sql/catalyst/trees/TreeNodeSuite.scala | 18 +++++++++------- .../spark/sql/execution/SparkPlan.scala | 5 +++-- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../LogicalPlanTagInSparkPlanSuite.scala | 11 +++++----- 6 files changed, 36 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 68a0fa499efb..94deacee3c87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -998,7 +998,7 @@ case class OneRowRelation() extends LeafNode { /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */ override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = { val newCopy = OneRowRelation() - newCopy.tags ++= this.tags + newCopy.copyTagsFrom(this) newCopy } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 957c5c24b08a..e094bfbb3ec7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -71,9 +71,8 @@ object CurrentOrigin { } } -// The name of the tree node tag. This is preferred over using string directly, as we can easily -// find all the defined tags. -case class TreeNodeTagName(name: String) +// A tag of a `TreeNode`, which defines name and type +case class TreeNodeTag[T](name: String) // scalastyle:off abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { @@ -86,7 +85,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * A mutable map for holding auxiliary information of this tree node. It will be carried over * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`. */ - val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty + private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty + + protected def copyTagsFrom(other: BaseType): Unit = { + tags ++= other.tags + } + + def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = { + tags(tag) = value + } + + def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = { + tags.get(tag).map(_.asInstanceOf[T]) + } /** * Returns a Seq of the children of this node. 
@@ -418,7 +429,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { try { CurrentOrigin.withOrigin(origin) { val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType] - res.tags ++= this.tags + res.copyTagsFrom(this) res } } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index d5619345f793..e94610e59c7a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -619,31 +619,33 @@ class TreeNodeSuite extends SparkFunSuite { } test("tags will be carried over after copy & transform") { + val tag = TreeNodeTag[String]("test") + withClue("makeCopy") { val node = Dummy(None) - node.tags += TreeNodeTagName("test") -> "a" + node.setTagValue(tag, "a") val copied = node.makeCopy(Array(Some(Literal(1)))) - assert(copied.tags(TreeNodeTagName("test")) == "a") + assert(copied.getTagValue(tag) == Some("a")) } def checkTransform( sameTypeTransform: Expression => Expression, differentTypeTransform: Expression => Expression): Unit = { val child = Dummy(None) - child.tags += TreeNodeTagName("test") -> "child" + child.setTagValue(tag, "child") val node = Dummy(Some(child)) - node.tags += TreeNodeTagName("test") -> "parent" + node.setTagValue(tag, "parent") val transformed = sameTypeTransform(node) // Both the child and parent keep the tags - assert(transformed.tags(TreeNodeTagName("test")) == "parent") - assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child") + assert(transformed.getTagValue(tag) == Some("parent")) + assert(transformed.children.head.getTagValue(tag) == Some("child")) val transformed2 = differentTypeTransform(node) // Both the child and parent keep the tags, even if we transform the node to a new one of // different type. - assert(transformed2.tags(TreeNodeTagName("test")) == "parent") - assert(transformed2.children.head.tags.contains(TreeNodeTagName("test"))) + assert(transformed2.getTagValue(tag) == Some("parent")) + assert(transformed2.children.head.getTagValue(tag) == Some("child")) } withClue("transformDown") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 967d5f640143..d83aa78fe4a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -33,15 +33,16 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _} import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.trees.TreeNodeTagName +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.DataType object SparkPlan { // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag // when converting a logical plan to a physical plan. 
- val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan") + val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 44b8db6645b6..0275942a18f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -68,7 +68,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ReturnAnswer(rootPlan) => rootPlan case _ => plan } - p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan + p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan) p } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala index ca7ced5ef538..b35348b4ea3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import org.apache.spark.sql.TPCDSQuerySuite import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final} -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} @@ -81,12 +80,12 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite { // The exchange related nodes are created after the planning, they don't have corresponding // logical plan. case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec => - assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME)) + assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty) // The subquery exec nodes are just wrappers of the actual nodes, they don't have // corresponding logical plan. 
case _: SubqueryExec | _: ReusedSubqueryExec => - assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME)) + assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty) case _ if isScanPlanTree(plan) => // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes, @@ -120,9 +119,9 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite { } private def getLogicalPlan(node: SparkPlan): LogicalPlan = { - assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME), - node.getClass.getSimpleName + " does not have a logical plan link") - node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan] + node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse { + fail(node.getClass.getSimpleName + " does not have a logical plan link") + } } private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = { From 7cfac94d89477e3833e791138ae4898188ab94a2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 31 Jan 2020 18:18:28 +0800 Subject: [PATCH 3/5] revert unnecessary changes --- .../spark/sql/execution/SparkPlan.scala | 10 +- .../spark/sql/execution/SparkStrategies.scala | 11 -- .../LogicalPlanTagInSparkPlanSuite.scala | 132 ------------------ 3 files changed, 2 insertions(+), 151 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index d83aa78fe4a1..7646f9613efb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext import org.codehaus.commons.compiler.CompileException import org.codehaus.janino.InternalCompilerException @@ -33,17 +34,10 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.DataType - -object SparkPlan { - // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag - // when converting a logical plan to a physical plan. - val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan") -} +import org.apache.spark.util.ThreadUtils /** * The base class for physical operators. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0275942a18f7..dbc6db62bd82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -62,17 +62,6 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode { abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => - override def plan(plan: LogicalPlan): Iterator[SparkPlan] = { - super.plan(plan).map { p => - val logicalPlan = plan match { - case ReturnAnswer(rootPlan) => rootPlan - case _ => plan - } - p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan) - p - } - } - /** * Plans special cases of limit operators. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala deleted file mode 100644 index b35348b4ea3b..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.reflect.ClassTag - -import org.apache.spark.sql.TPCDSQuerySuite -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window} -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.execution.window.WindowExec - -class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite { - - override protected def checkGeneratedCode(plan: SparkPlan): Unit = { - super.checkGeneratedCode(plan) - checkLogicalPlanTag(plan) - } - - private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = { - // TODO: aggregate node without aggregate expressions can also be a final aggregate, but - // currently the aggregate node doesn't have a final/partial flag. 
- aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final) - } - - // A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes. - private def isScanPlanTree(plan: SparkPlan): Boolean = plan match { - case p: ProjectExec => isScanPlanTree(p.child) - case f: FilterExec => isScanPlanTree(f.child) - case _: LeafExecNode => true - case _ => false - } - - private def checkLogicalPlanTag(plan: SparkPlan): Unit = { - plan match { - case _: HashJoin | _: BroadcastNestedLoopJoinExec | _: CartesianProductExec - | _: ShuffledHashJoinExec | _: SortMergeJoinExec => - assertLogicalPlanType[Join](plan) - - // There is no corresponding logical plan for the physical partial aggregate. - case agg: HashAggregateExec if isFinalAgg(agg.aggregateExpressions) => - assertLogicalPlanType[Aggregate](plan) - case agg: ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) => - assertLogicalPlanType[Aggregate](plan) - case agg: SortAggregateExec if isFinalAgg(agg.aggregateExpressions) => - assertLogicalPlanType[Aggregate](plan) - - case _: WindowExec => - assertLogicalPlanType[Window](plan) - - case _: UnionExec => - assertLogicalPlanType[Union](plan) - - case _: SampleExec => - assertLogicalPlanType[Sample](plan) - - case _: GenerateExec => - assertLogicalPlanType[Generate](plan) - - // The exchange related nodes are created after the planning, they don't have corresponding - // logical plan. - case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec => - assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty) - - // The subquery exec nodes are just wrappers of the actual nodes, they don't have - // corresponding logical plan. - case _: SubqueryExec | _: ReusedSubqueryExec => - assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty) - - case _ if isScanPlanTree(plan) => - // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes, - // so it's not simple to check. Instead, we only check that the origin LogicalPlan - // contains the corresponding leaf node of the SparkPlan. - // a strategy might remove the filter if it's totally pushed down, e.g.: - // logical = Project(Filter(Scan A)) - // physical = ProjectExec(ScanExec A) - // we only check that leaf modes match between logical and physical plan. - val logicalLeaves = getLogicalPlan(plan).collectLeaves() - val physicalLeaves = plan.collectLeaves() - assert(logicalLeaves.length == 1) - assert(physicalLeaves.length == 1) - physicalLeaves.head match { - case _: RangeExec => logicalLeaves.head.isInstanceOf[Range] - case _: DataSourceScanExec => logicalLeaves.head.isInstanceOf[LogicalRelation] - case _: InMemoryTableScanExec => logicalLeaves.head.isInstanceOf[InMemoryRelation] - case _: LocalTableScanExec => logicalLeaves.head.isInstanceOf[LocalRelation] - case _: ExternalRDDScanExec[_] => logicalLeaves.head.isInstanceOf[ExternalRDD[_]] - case _: BatchScanExec => logicalLeaves.head.isInstanceOf[DataSourceV2Relation] - case _ => - } - // Do not need to check the children recursively. 
- return - - case _ => - } - - plan.children.foreach(checkLogicalPlanTag) - plan.subqueries.foreach(checkLogicalPlanTag) - } - - private def getLogicalPlan(node: SparkPlan): LogicalPlan = { - node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse { - fail(node.getClass.getSimpleName + " does not have a logical plan link") - } - } - - private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = { - val logicalPlan = getLogicalPlan(node) - val expectedCls = implicitly[ClassTag[T]].runtimeClass - assert(expectedCls == logicalPlan.getClass) - } -} From 11b07e0f9da413cb5fb18f269fcfe6ed7f5204c4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 6 Aug 2019 10:06:36 +0800 Subject: [PATCH 4/5] [SPARK-28344][SQL] detect ambiguous self-join and fail the query MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is an alternative solution of https://github.com/apache/spark/pull/24442 . It fails the query if ambiguous self join is detected, instead of trying to disambiguate it. The problem is that, it's hard to come up with a reasonable rule to disambiguate, the rule proposed by #24442 is mostly a heuristic. This is a long-standing bug and I've seen many people complaining about it in JIRA/dev list. A typical example: ``` val df1 = … val df2 = df1.filter(...) df1.join(df2, df1("a") > df2("a")) // returns empty result ``` The root cause is, `Dataset.apply` is so powerful that users think it returns a column reference which can point to the column of the Dataset at anywhere. This is not true in many cases. `Dataset.apply` returns an `AttributeReference` . Different Datasets may share the same `AttributeReference`. In the example above, `df2` adds a Filter operator above the logical plan of `df1`, and the Filter operator reserves the output `AttributeReference` of its child. This means, `df1("a")` is exactly the same as `df2("a")`, and `df1("a") > df2("a")` always evaluates to false. We can reuse the infra in #24442 : 1. each Dataset has a globally unique id. 2. the `AttributeReference` returned by `Dataset.apply` carries the ID and column position(e.g. 3rd column of the Dataset) via metadata. 3. the logical plan of a `Dataset` carries the ID via `TreeNodeTag` When self-join happens, the analyzer asks the right side plan of join to re-generate output attributes with new exprIds. Based on it, a simple rule to detect ambiguous self join is: 1. find all column references (i.e. `AttributeReference`s with Dataset ID and col position) in the root node of a query plan. 2. for each column reference, traverse the query plan tree, find a sub-plan that carries Dataset ID and the ID is the same as the one in the column reference. 3. get the corresponding output attribute of the sub-plan by the col position in the column reference. 4. if the corresponding output attribute has a different exprID than the column reference, then it means this sub-plan is on the right side of a self-join and has regenerated its output attributes. This is an ambiguous self join because the column reference points to a table being self-joined. existing tests and new test cases Closes #25107 from cloud-fan/new-self-join. 
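To make the behavior change concrete, here is an illustrative sketch only (not code from this patch), assuming an active `SparkSession` named `spark` with `spark.implicits._` in scope:

```scala
// Assumed context: val spark: SparkSession, plus `import spark.implicits._`.
val df1 = spark.range(3)          // ids 0, 1, 2
val df2 = df1.filter($"id" > 0)   // ids 1, 2 -- same underlying plan, with a Filter on top

// df1("id") and df2("id") are the same AttributeReference, so this condition is
// ambiguous. With spark.sql.analyzer.failAmbiguousSelfJoin=true (the new default),
// the query below now fails with an AnalysisException instead of silently
// returning an empty result:
//   df1.join(df2, df1("id") > df2("id"))

// Aliasing the Datasets and using qualified column names disambiguates the join;
// this returns the single row (2, 1).
val ok = df1.as("a").join(df2.as("b"), $"a.id" > $"b.id")

// The previous behavior can still be restored via the new config if needed.
spark.conf.set("spark.sql.analyzer.failAmbiguousSelfJoin", "false")
```

This mirrors the error message added by `DetectAmbiguousSelfJoin`, which points users at `Dataset.as` plus qualified column names, and at `spark.sql.analyzer.failAmbiguousSelfJoin` as an escape hatch.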
Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 6 +- .../sql/catalyst/analysis/Analyzer.scala | 5 +- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../scala/org/apache/spark/sql/Column.scala | 19 +- .../scala/org/apache/spark/sql/Dataset.scala | 40 +++- .../analysis/DetectAmbiguousSelfJoin.scala | 162 ++++++++++++++ .../internal/BaseSessionStateBuilder.scala | 4 +- .../spark/sql/DataFrameAggregateSuite.scala | 15 ++ .../apache/spark/sql/DataFrameJoinSuite.scala | 51 ----- .../spark/sql/DataFrameSelfJoinSuite.scala | 205 ++++++++++++++++++ .../sql/hive/HiveSessionStateBuilder.scala | 4 +- .../sql/hive/HiveDataFrameJoinSuite.scala | 38 ---- 12 files changed, 455 insertions(+), 101 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 817063770f68..10220ae120c4 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -16,11 +16,13 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 2.4.5, `TRUNCATE TABLE` command tries to set back original permission and ACLs during re-creating the table/partition paths. To restore the behaviour of earlier versions, set `spark.sql.truncateTable.ignorePermissionAcl.enabled` to `true`. - - Since Spark 2.4.5, `spark.sql.legacy.mssqlserver.numericMapping.enabled` configuration is added in order to support the legacy MsSQLServer dialect mapping behavior using IntegerType and DoubleType for SMALLINT and REAL JDBC types, respectively. To restore the behaviour of 2.4.3 and earlier versions, set `spark.sql.legacy.mssqlserver.numericMapping.enabled` to `true`. + - Since Spark 2.4.5, `spark.sql.legacy.mssqlserver.numericMapping.enabled` configuration is added in order to support the legacy MsSQLServer dialect mapping behavior using IntegerType and DoubleType for SMALLINT and REAL JDBC types, respectively. To restore the behaviour of 2.4.3 and earlier versions, set `spark.sql.legacy.mssqlserver.numericMapping.enabled` to `true`. + + - Since Spark 2.4.5, Dataset query fails if it contains ambiguous column reference that is caused by self join. A typical example: `val df1 = ...; val df2 = df1.filter(...);`, then `df1.join(df2, df1("a") > df2("a"))` returns an empty result which is quite confusing. This is because Spark cannot resolve Dataset column references that point to tables being self joined, and `df1("a")` is exactly the same as `df2("a")` in Spark. To restore the behavior before Spark 3.0, you can set `spark.sql.analyzer.failAmbiguousSelfJoin` to `false`. ## Upgrading from Spark SQL 2.4.3 to 2.4.4 - - Since Spark 2.4.4, according to [MsSqlServer Guide](https://docs.microsoft.com/en-us/sql/connect/jdbc/using-basic-data-types?view=sql-server-2017), MsSQLServer JDBC Dialect uses ShortType and FloatType for SMALLINT and REAL, respectively. Previously, IntegerType and DoubleType is used. + - Since Spark 2.4.4, according to [MsSqlServer Guide](https://docs.microsoft.com/en-us/sql/connect/jdbc/using-basic-data-types?view=sql-server-2017), MsSQLServer JDBC Dialect uses ShortType and FloatType for SMALLINT and REAL, respectively. Previously, IntegerType and DoubleType is used. 
## Upgrading from Spark SQL 2.4 to 2.4.1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f8b95139cab3..9030de70f740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -334,7 +334,8 @@ class Analyzer( gid: Expression): Expression = { expr transform { case e: GroupingID => - if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) { + if (e.groupByExprs.isEmpty || + e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) { Alias(gid, toPrettySQL(e))() } else { throw new AnalysisException( @@ -936,6 +937,8 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) + // intersect/except will be rewritten to join at the begininng of optimizer. Here we need to + // deduplicate the right side plan, so that we won't produce an invalid self-join later. case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) case e @ Except(left, right, _) if !e.duplicateResolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 760a9db8bead..3a471b683df9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -726,6 +726,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val FAIL_AMBIGUOUS_SELF_JOIN = + buildConf("spark.sql.analyzer.failAmbiguousSelfJoin") + .doc("When true, fail the Dataset query if it contains ambiguous self-join.") + .internal() + .booleanConf + .createWithDefault(true) + // Whether to retain group by columns or not in GroupedData.agg. val DATAFRAME_RETAIN_GROUP_COLUMNS = buildConf("spark.sql.retainGroupColumns") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index a046127c3edb..bcb7cdac1327 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -48,6 +48,15 @@ private[sql] object Column { case expr => toPrettySQL(expr) } } + + private[sql] def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = { + val metadataWithoutId = new MetadataBuilder() + .withMetadata(a.metadata) + .remove(Dataset.DATASET_ID_KEY) + .remove(Dataset.COL_POS_KEY) + .build() + a.withMetadata(metadataWithoutId) + } } /** @@ -141,11 +150,15 @@ class Column(val expr: Expression) extends Logging { override def toString: String = toPrettySQL(expr) override def equals(that: Any): Boolean = that match { - case that: Column => that.expr.equals(this.expr) + case that: Column => that.normalizedExpr() == this.normalizedExpr() case _ => false } - override def hashCode: Int = this.expr.hashCode() + override def hashCode: Int = this.normalizedExpr().hashCode() + + private def normalizedExpr(): Expression = expr transform { + case a: AttributeReference => Column.stripColumnReferenceMetadata(a) + } /** Creates a column based on the given expression. 
*/ private def withExpr(newExpr: Expression): Column = new Column(newExpr) @@ -1023,7 +1036,7 @@ class Column(val expr: Expression) extends Logging { * @since 2.0.0 */ def name(alias: String): Column = withExpr { - expr match { + normalizedExpr() match { case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) case other => Alias(other, alias)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c90b2e857e66..63e2433576b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -45,12 +45,14 @@ import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -60,6 +62,11 @@ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils private[sql] object Dataset { + val curId = new java.util.concurrent.atomic.AtomicLong() + val DATASET_ID_KEY = "__dataset_id" + val COL_POS_KEY = "__col_position" + val DATASET_ID_TAG = TreeNodeTag[Long]("dataset_id") + def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) // Eagerly bind the encoder so we verify that the encoder matches the underlying @@ -173,6 +180,9 @@ class Dataset[T] private[sql]( encoder: Encoder[T]) extends Serializable { + // A globally unique id of this Dataset. + private val id = Dataset.curId.getAndIncrement() + queryExecution.assertAnalyzed() // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure @@ -189,7 +199,7 @@ class Dataset[T] private[sql]( @transient private[sql] val logicalPlan: LogicalPlan = { // For various commands (like DDL) and queries with side effects, we force query execution // to happen right away to let these side effects take place eagerly. 
- queryExecution.analyzed match { + val plan = queryExecution.analyzed match { case c: Command => LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect())) case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => @@ -197,6 +207,10 @@ class Dataset[T] private[sql]( case _ => queryExecution.analyzed } + if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN)) { + plan.setTagValue(Dataset.DATASET_ID_TAG, id) + } + plan } /** @@ -1271,11 +1285,29 @@ class Dataset[T] private[sql]( if (sqlContext.conf.supportQuotedRegexColumnName) { colRegex(colName) } else { - val expr = resolve(colName) - Column(expr) + Column(addDataFrameIdToCol(resolve(colName))) } } + // Attach the dataset id and column position to the column reference, so that we can detect + // ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`. + // This must be called before we return a `Column` that contains `AttributeReference`. + // Note that, the metadata added here are only avaiable in the analyzer, as the analyzer rule + // `DetectAmbiguousSelfJoin` will remove it. + private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = { + val newExpr = expr transform { + case a: AttributeReference + if sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN) => + val metadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(Dataset.DATASET_ID_KEY, id) + .putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals)) + .build() + a.withMetadata(metadata) + } + newExpr.asInstanceOf[NamedExpression] + } + /** * Selects column based on the column name specified as a regex and returns it as [[Column]]. * @group untypedrel @@ -1289,7 +1321,7 @@ class Dataset[T] private[sql]( case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) => Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)) case _ => - Column(resolve(colName)) + Column(addDataFrameIdToCol(resolve(colName))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala new file mode 100644 index 000000000000..5c3c735f0346 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.analysis + +import scala.collection.mutable + +import org.apache.spark.sql.{AnalysisException, Column, Dataset} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Equality, Expression, ExprId} +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + +/** + * Detects ambiguous self-joins, so that we can fail the query instead of returning confusing + * results. + * + * Dataset column reference is simply an [[AttributeReference]] that is returned by `Dataset#col`. + * Most of time we don't need to do anything special, as [[AttributeReference]] can point to + * the column precisely. However, in case of self-join, the analyzer generates + * [[AttributeReference]] with new expr IDs for the right side plan of the join. If the Dataset + * column reference points to a column in the right side plan of a self-join, users will get + * unexpected result because the column reference can't match the newly generated + * [[AttributeReference]]. + * + * Note that, this rule removes all the Dataset id related metadata from `AttributeReference`, so + * that they don't exist after analyzer. + */ +class DetectAmbiguousSelfJoin(conf: SQLConf) extends Rule[LogicalPlan] { + + // Dataset column reference is an `AttributeReference` with 2 special metadata. + private def isColumnReference(a: AttributeReference): Boolean = { + a.metadata.contains(Dataset.DATASET_ID_KEY) && a.metadata.contains(Dataset.COL_POS_KEY) + } + + private case class ColumnReference(datasetId: Long, colPos: Int, exprId: ExprId) + + private def toColumnReference(a: AttributeReference): ColumnReference = { + ColumnReference( + a.metadata.getLong(Dataset.DATASET_ID_KEY), + a.metadata.getLong(Dataset.COL_POS_KEY).toInt, + a.exprId) + } + + object LogicalPlanWithDatasetId { + def unapply(p: LogicalPlan): Option[(LogicalPlan, Long)] = { + p.getTagValue(Dataset.DATASET_ID_TAG).map(id => p -> id) + } + } + + object AttrWithCast { + def unapply(expr: Expression): Option[AttributeReference] = expr match { + case Cast(child, _, _) => unapply(child) + case a: AttributeReference => Some(a) + case _ => None + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN)) return plan + + // We always remove the special metadata from `AttributeReference` at the end of this rule, so + // Dataset column reference only exists in the root node via Dataset transformations like + // `Dataset#select`. + val colRefAttrs = plan.expressions.flatMap(_.collect { + case a: AttributeReference if isColumnReference(a) => a + }) + + if (colRefAttrs.nonEmpty) { + val colRefs = colRefAttrs.map(toColumnReference).distinct + val ambiguousColRefs = mutable.HashSet.empty[ColumnReference] + val dsIdSet = colRefs.map(_.datasetId).toSet + + plan.foreach { + case LogicalPlanWithDatasetId(p, id) if dsIdSet.contains(id) => + colRefs.foreach { ref => + if (id == ref.datasetId) { + if (ref.colPos < 0 || ref.colPos >= p.output.length) { + throw new IllegalStateException("[BUG] Hit an invalid Dataset column reference: " + + s"$ref. Please open a JIRA ticket to report it.") + } else { + // When self-join happens, the analyzer asks the right side plan to generate + // attributes with new exprIds. 
If a plan of a Dataset outputs an attribute which + // is referred by a column reference, and this attribute has different exprId than + // the attribute of column reference, then the column reference is ambiguous, as it + // refers to a column that gets regenerated by self-join. + val actualAttr = p.output(ref.colPos).asInstanceOf[AttributeReference] + if (actualAttr.exprId != ref.exprId) { + ambiguousColRefs += ref + } + } + } + } + + case _ => + } + + val ambiguousAttrs: Seq[AttributeReference] = plan match { + case Join( + LogicalPlanWithDatasetId(_, leftId), + LogicalPlanWithDatasetId(_, rightId), + _, condition, _) => + // If we are dealing with root join node, we need to take care of SPARK-6231: + // 1. We can de-ambiguous `df("col") === df("col")` in the join condition. + // 2. There is no ambiguity in direct self join like + // `df.join(df, df("col") === 1)`, because it doesn't matter which side the + // column comes from. + def getAmbiguousAttrs(expr: Expression): Seq[AttributeReference] = expr match { + case Equality(AttrWithCast(a), AttrWithCast(b)) if a.sameRef(b) => + Nil + case Equality(AttrWithCast(a), b) if leftId == rightId && b.foldable => + Nil + case Equality(a, AttrWithCast(b)) if leftId == rightId && a.foldable => + Nil + case a: AttributeReference => + if (isColumnReference(a)) { + val colRef = toColumnReference(a) + if (ambiguousColRefs.contains(colRef)) Seq(a) else Nil + } else { + Nil + } + case _ => expr.children.flatMap(getAmbiguousAttrs) + } + condition.toSeq.flatMap(getAmbiguousAttrs) + + case _ => ambiguousColRefs.toSeq.map { ref => + colRefAttrs.find(attr => toColumnReference(attr) == ref).get + } + } + + if (ambiguousAttrs.nonEmpty) { + throw new AnalysisException(s"Column ${ambiguousAttrs.mkString(", ")} are ambiguous. " + + "It's probably because you joined several Datasets together, and some of these " + + "Datasets are the same. This column points to one of the Datasets but Spark is unable " + + "to figure out which one. Please alias the Datasets with different names via " + + "`Dataset.as` before joining them, and specify the column using qualified name, e.g. " + + """`df.as("a").join(df.as("b"), $"a.id" > $"b.id")`. You can also set """ + + s"${SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key} to false to disable this check.") + } + } + + plan.transformExpressions { + case a: AttributeReference if isColumnReference(a) => + // Remove the special metadata from this `AttributeReference`, as the detection is done. 
+ Column.stripColumnReferenceMetadata(a) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 9c1a15c46acd..b83d97a32316 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager @@ -161,7 +162,8 @@ abstract class BaseSessionStateBuilder( customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = - PreprocessTableCreation(session) +: + new DetectAmbiguousSelfJoin(conf) +: + PreprocessTableCreation(session) +: PreprocessTableInsertion(conf) +: DataSourceAnalysis(conf) +: customPostHocResolutionRules diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d0106c44b7db..0b53b49a6b01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -167,6 +167,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, null, 1, 1, 3) :: Nil ) + // use column reference in `grouping_id` instead of column name + checkAnswer( + courseSales.cube("course", "year") + .agg(grouping_id(courseSales("course"), courseSales("year"))), + Row("Java", 2012, 0) :: + Row("Java", 2013, 0) :: + Row("Java", null, 1) :: + Row("dotNET", 2012, 0) :: + Row("dotNET", 2013, 0) :: + Row("dotNET", null, 1) :: + Row(null, 2012, 2) :: + Row(null, 2013, 2) :: + Row(null, null, 3) :: Nil + ) + intercept[AnalysisException] { courseSales.groupBy().agg(grouping("course")).explain() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e6b30f9956da..61bf91a27253 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -86,25 +86,6 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, "1", 2) :: Nil) } - test("join - join using self join") { - val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") - - // self join - checkAnswer( - df.join(df, "int"), - Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil) - } - - test("join - self join") { - val df1 = testData.select(testData("key")).as('df1) - val df2 = testData.select(testData("key")).as('df2) - - checkAnswer( - df1.join(df2, $"df1.key" === $"df2.key"), - sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") - .collect().toSeq) - } - test("join - cross join") { val df1 = Seq((1, "1"), (3, "3")).toDF("int", "str") val df2 = Seq((2, "2"), (4, "4")).toDF("int", "str") @@ -120,38 +101,6 @@ class DataFrameJoinSuite extends QueryTest with 
SharedSQLContext { Row(4, "4", 1, "1") :: Row(4, "4", 3, "3") :: Nil) } - test("join - using aliases after self join") { - val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") - checkAnswer( - df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), - Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) - - checkAnswer( - df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), - Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) - } - - test("[SPARK-6231] join - self join auto resolve ambiguity") { - val df = Seq((1, "1"), (2, "2")).toDF("key", "value") - checkAnswer( - df.join(df, df("key") === df("key")), - Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) - - checkAnswer( - df.join(df.filter($"value" === "2"), df("key") === df("key")), - Row(2, "2", 2, "2") :: Nil) - - checkAnswer( - df.join(df, df("key") === df("key") && df("value") === 1), - Row(1, "1", 1, "1") :: Nil) - - val left = df.groupBy("key").agg(count("*")) - val right = df.groupBy("key").agg(sum("key")) - checkAnswer( - left.join(right, left("key") === right("key")), - Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) - } - test("broadcast join hint using broadcast function") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala new file mode 100644 index 000000000000..92f1e4306c5b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions.{count, sum} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("join - join using self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + + // self join + checkAnswer( + df.join(df, "int"), + Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil) + } + + test("join - self join") { + val df1 = testData.select(testData("key")).as('df1) + val df2 = testData.select(testData("key")).as('df2) + + checkAnswer( + df1.join(df2, $"df1.key" === $"df2.key"), + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + .collect().toSeq) + } + + test("join - self join auto resolve ambiguity with case insensitivity") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.join(df, df("key") === df("Key")), + Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df.filter($"value" === "2"), df("key") === df("Key")), + Row(2, "2", 2, "2") :: Nil) + } + + test("join - using aliases after self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + test("[SPARK-6231] join - self join auto resolve ambiguity") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.join(df, df("key") === df("key")), + Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df.filter($"value" === "2"), df("key") === df("key")), + Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df, df("key") === df("key") && df("value") === 1), + Row(1, "1", 1, "1") :: Nil) + + val left = df.groupBy("key").agg(count("*")) + val right = df.groupBy("key").agg(sum("key")) + checkAnswer( + left.join(right, left("key") === right("key")), + Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) + } + + private def assertAmbiguousSelfJoin(df: => DataFrame): Unit = { + val e = intercept[AnalysisException](df) + assert(e.message.contains("ambiguous")) + } + + test("SPARK-28344: fail ambiguous self join - column ref in join condition") { + val df1 = spark.range(3) + val df2 = df1.filter($"id" > 0) + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "false", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // `df1("id") > df2("id")` is always false. + checkAnswer(df1.join(df2, df1("id") > df2("id")), Nil) + + // Alias the dataframe and use qualified column names can fix ambiguous self-join. 
+ val aliasedDf1 = df1.alias("left") + val aliasedDf2 = df2.as("right") + checkAnswer( + aliasedDf1.join(aliasedDf2, $"left.id" > $"right.id"), + Seq(Row(2, 1))) + } + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertAmbiguousSelfJoin(df1.join(df2, df1("id") > df2("id"))) + } + } + + test("SPARK-28344: fail ambiguous self join - Dataset.colRegex as column ref") { + val df1 = spark.range(3) + val df2 = df1.filter($"id" > 0) + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertAmbiguousSelfJoin(df1.join(df2, df1.colRegex("id") > df2.colRegex("id"))) + } + } + + test("SPARK-28344: fail ambiguous self join - Dataset.col with nested field") { + val df1 = spark.read.json(Seq("""{"a": {"b": 1, "c": 1}}""").toDS()) + val df2 = df1.filter($"a.b" > 0) + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertAmbiguousSelfJoin(df1.join(df2, df1("a.b") > df2("a.c"))) + } + } + + test("SPARK-28344: fail ambiguous self join - column ref in Project") { + val df1 = spark.range(3) + val df2 = df1.filter($"id" > 0) + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "false", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // `df2("id")` actually points to the column of `df1`. + checkAnswer(df1.join(df2).select(df2("id")), Seq(0, 0, 1, 1, 2, 2).map(Row(_))) + + // Alias the dataframe and use qualified column names can fix ambiguous self-join. + val aliasedDf1 = df1.alias("left") + val aliasedDf2 = df2.as("right") + checkAnswer( + aliasedDf1.join(aliasedDf2).select($"right.id"), + Seq(1, 1, 1, 2, 2, 2).map(Row(_))) + } + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertAmbiguousSelfJoin(df1.join(df2).select(df2("id"))) + } + } + + test("SPARK-28344: fail ambiguous self join - join three tables") { + val df1 = spark.range(3) + val df2 = df1.filter($"id" > 0) + val df3 = df1.filter($"id" <= 2) + val df4 = spark.range(1) + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "false", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // `df2("id") < df3("id")` is always false + checkAnswer(df1.join(df2).join(df3, df2("id") < df3("id")), Nil) + // `df2("id")` actually points to the column of `df1`. + checkAnswer( + df1.join(df4).join(df2).select(df2("id")), + Seq(0, 0, 1, 1, 2, 2).map(Row(_))) + // `df4("id")` is not ambiguous. + checkAnswer( + df1.join(df4).join(df2).select(df4("id")), + Seq(0, 0, 0, 0, 0, 0).map(Row(_))) + + // Alias the dataframe and use qualified column names can fix ambiguous self-join. 
+ val aliasedDf1 = df1.alias("x") + val aliasedDf2 = df2.as("y") + val aliasedDf3 = df3.as("z") + checkAnswer( + aliasedDf1.join(aliasedDf2).join(aliasedDf3, $"y.id" < $"z.id"), + Seq(Row(0, 1, 2), Row(1, 1, 2), Row(2, 1, 2))) + checkAnswer( + aliasedDf1.join(df4).join(aliasedDf2).select($"y.id"), + Seq(1, 1, 1, 2, 2, 2).map(Row(_))) + } + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertAmbiguousSelfJoin(df1.join(df2).join(df3, df2("id") < df3("id"))) + assertAmbiguousSelfJoin(df1.join(df4).join(df2).select(df2("id"))) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2882672f327c..9d43ee1c12e3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner +import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} @@ -74,7 +75,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = - new DetermineTableStats(session) +: + new DetectAmbiguousSelfJoin(conf) +: + new DetermineTableStats(session) +: RelationConversions(conf, catalog) +: PreprocessTableCreation(session) +: PreprocessTableInsertion(conf) +: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala deleted file mode 100644 index cdc259d75b13..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHiveSingleton - -class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { - import spark.implicits._ - - // We should move this into SQL package if we make case sensitivity configurable in SQL. 
- test("join - self join auto resolve ambiguity with case insensitivity") { - val df = Seq((1, "1"), (2, "2")).toDF("key", "value") - checkAnswer( - df.join(df, df("key") === df("Key")), - Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) - - checkAnswer( - df.join(df.filter($"value" === "2"), df("key") === df("Key")), - Row(2, "2", 2, "2") :: Nil) - } - -} From 8199e30545758bcae0e8deaee06de430a09a0d76 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 31 Jan 2020 20:07:18 +0800 Subject: [PATCH 5/5] fix --- .../spark/sql/catalyst/expressions/namedExpressions.scala | 2 +- .../spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 049ea7769139..142c56b361d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -308,7 +308,7 @@ case class AttributeReference( } } - override def withMetadata(newMetadata: Metadata): Attribute = { + override def withMetadata(newMetadata: Metadata): AttributeReference = { AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala index 5c3c735f0346..a0eeadd69bef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala @@ -113,7 +113,7 @@ class DetectAmbiguousSelfJoin(conf: SQLConf) extends Rule[LogicalPlan] { case Join( LogicalPlanWithDatasetId(_, leftId), LogicalPlanWithDatasetId(_, rightId), - _, condition, _) => + _, condition) => // If we are dealing with root join node, we need to take care of SPARK-6231: // 1. We can de-ambiguous `df("col") === df("col")` in the join condition. // 2. There is no ambiguity in direct self join like