diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md
index 817063770f68..10220ae120c4 100644
--- a/docs/sql-migration-guide-upgrade.md
+++ b/docs/sql-migration-guide-upgrade.md
@@ -16,11 +16,13 @@ displayTitle: Spark SQL Upgrading Guide
 
   - Since Spark 2.4.5, `TRUNCATE TABLE` command tries to set back original permission and ACLs during re-creating the table/partition paths. To restore the behaviour of earlier versions, set `spark.sql.truncateTable.ignorePermissionAcl.enabled` to `true`.
 
-  - Since Spark 2.4.5, `spark.sql.legacy.mssqlserver.numericMapping.enabled` configuration is added in order to support the legacy MsSQLServer dialect mapping behavior using IntegerType and DoubleType for SMALLINT and REAL JDBC types, respectively. To restore the behaviour of 2.4.3 and earlier versions, set `spark.sql.legacy.mssqlserver.numericMapping.enabled` to `true`.
+  - Since Spark 2.4.5, `spark.sql.legacy.mssqlserver.numericMapping.enabled` configuration is added in order to support the legacy MsSQLServer dialect mapping behavior using IntegerType and DoubleType for SMALLINT and REAL JDBC types, respectively. To restore the behaviour of 2.4.3 and earlier versions, set `spark.sql.legacy.mssqlserver.numericMapping.enabled` to `true`.
+
+  - Since Spark 2.4.5, a Dataset query fails if it contains an ambiguous column reference caused by a self join. A typical example: `val df1 = ...; val df2 = df1.filter(...);`, then `df1.join(df2, df1("a") > df2("a"))` returns an empty result, which is quite confusing. This is because Spark cannot resolve Dataset column references that point to tables being self joined, and `df1("a")` is exactly the same as `df2("a")` in Spark. To restore the behaviour of earlier versions, set `spark.sql.analyzer.failAmbiguousSelfJoin` to `false`.
 
 ## Upgrading from Spark SQL 2.4.3 to 2.4.4
 
-  - Since Spark 2.4.4, according to [MsSqlServer Guide](https://docs.microsoft.com/en-us/sql/connect/jdbc/using-basic-data-types?view=sql-server-2017), MsSQLServer JDBC Dialect uses ShortType and FloatType for SMALLINT and REAL, respectively. Previously, IntegerType and DoubleType is used.
+  - Since Spark 2.4.4, according to [MsSqlServer Guide](https://docs.microsoft.com/en-us/sql/connect/jdbc/using-basic-data-types?view=sql-server-2017), MsSQLServer JDBC Dialect uses ShortType and FloatType for SMALLINT and REAL, respectively. Previously, IntegerType and DoubleType is used.
 
 ## Upgrading from Spark SQL 2.4 to 2.4.1
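The migration note above is worth seeing end to end. Below is a minimal sketch of the new behavior, assuming a build with this patch applied and an active `SparkSession` named `spark`; before the patch the first join silently returned an empty result.

```scala
import org.apache.spark.sql.AnalysisException
import spark.implicits._

val df1 = spark.range(3).toDF("a")
val df2 = df1.filter($"a" > 0)

// df1("a") and df2("a") resolve to the same attribute, so this condition can
// never distinguish the two sides. It used to return an empty result; with
// spark.sql.analyzer.failAmbiguousSelfJoin=true (the default) it now fails.
try {
  df1.join(df2, df1("a") > df2("a")).show()
} catch {
  case e: AnalysisException => println(e.getMessage) // "... are ambiguous ..."
}

// The documented fix: alias both sides and join on qualified column names.
df1.as("l").join(df2.as("r"), $"l.a" > $"r.a").show()
```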
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f8b95139cab3..9030de70f740 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -334,7 +334,8 @@ class Analyzer(
       gid: Expression): Expression = {
     expr transform {
       case e: GroupingID =>
-        if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
+        if (e.groupByExprs.isEmpty ||
+            e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) {
           Alias(gid, toPrettySQL(e))()
         } else {
           throw new AnalysisException(
@@ -936,6 +937,8 @@ class Analyzer(
       // To resolve duplicate expression IDs for Join and Intersect
       case j @ Join(left, right, _, _) if !j.duplicateResolved =>
         j.copy(right = dedupRight(left, right))
+      // Intersect/Except will be rewritten to Join at the beginning of the optimizer. Here we need
+      // to deduplicate the right side plan, so that we won't produce an invalid self-join later.
       case i @ Intersect(left, right, _) if !i.duplicateResolved =>
         i.copy(right = dedupRight(left, right))
       case e @ Except(left, right, _) if !e.duplicateResolved =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 049ea7769139..142c56b361d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -308,7 +308,7 @@ case class AttributeReference(
     }
   }
 
-  override def withMetadata(newMetadata: Metadata): Attribute = {
+  override def withMetadata(newMetadata: Metadata): AttributeReference = {
     AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 1c255861ad3a..94deacee3c87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -996,7 +996,11 @@ case class OneRowRelation() extends LeafNode {
   override def computeStats(): Statistics = Statistics(sizeInBytes = 1)
 
   /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
-  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = OneRowRelation()
+  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
+    val newCopy = OneRowRelation()
+    newCopy.copyTagsFrom(this)
+    newCopy
+  }
 }
 
 /** A logical plan for `dropDuplicates`. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index a924f10fb366..e094bfbb3ec7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.util.UUID
 
-import scala.collection.Map
+import scala.collection.{mutable, Map}
 import scala.reflect.ClassTag
 
 import org.apache.commons.lang3.ClassUtils
@@ -71,6 +71,9 @@ object CurrentOrigin {
   }
 }
 
+// A tag of a `TreeNode`, which defines the tag name and the type of the tag value.
+case class TreeNodeTag[T](name: String)
+
 // scalastyle:off
 abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 // scalastyle:on
@@ -78,6 +81,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   val origin: Origin = CurrentOrigin.get
 
+  /**
+   * A mutable map for holding auxiliary information of this tree node. It will be carried over
+   * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
+   */
+  private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty
+
+  protected def copyTagsFrom(other: BaseType): Unit = {
+    tags ++= other.tags
+  }
+
+  def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
+    tags(tag) = value
+  }
+
+  def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = {
+    tags.get(tag).map(_.asInstanceOf[T])
+  }
+
   /**
    * Returns a Seq of the children of this node.
   * Children should not change. Immutability required for containsChild optimization
@@ -262,6 +283,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     if (this fastEquals afterRule) {
       mapChildren(_.transformDown(rule))
     } else {
+      // If the transform function replaces this node with a new one, carry over the tags.
+      afterRule.tags ++= this.tags
       afterRule.mapChildren(_.transformDown(rule))
     }
   }
@@ -275,7 +298,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    */
  def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
    val afterRuleOnChildren = mapChildren(_.transformUp(rule))
-    if (this fastEquals afterRuleOnChildren) {
+    val newNode = if (this fastEquals afterRuleOnChildren) {
      CurrentOrigin.withOrigin(origin) {
        rule.applyOrElse(this, identity[BaseType])
      }
@@ -284,6 +307,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
      }
    }
+    // If the transform function replaces this node with a new one, carry over the tags.
+    newNode.tags ++= this.tags
+    newNode
  }
 
  /**
@@ -402,7 +428,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    try {
      CurrentOrigin.withOrigin(origin) {
-        defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
+        val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
+        res.copyTagsFrom(this)
+        res
      }
    } catch {
      case e: java.lang.IllegalArgumentException =>
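The tag facility above is the linchpin of the patch: `Dataset` will stamp its id onto its analyzed plan, and the tag must survive `makeCopy` and both transform directions. A small sketch of the intended semantics, using existing catalyst expression nodes (this mirrors the new `TreeNodeSuite` test later in this diff):

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

val tag = TreeNodeTag[String]("owner")
val expr = Add(Literal(1), Literal(2))
expr.setTagValue(tag, "dataset-42")

// The tag survives a transform that rebuilds the node with new children...
val rewritten = expr.transformUp { case Literal(1, t) => Literal(0, t) }
assert(rewritten.getTagValue(tag) == Some("dataset-42"))

// ...and a makeCopy-based copy, since makeCopy now calls copyTagsFrom.
val copied = expr.makeCopy(expr.productIterator.map(_.asInstanceOf[AnyRef]).toArray)
assert(copied.getTagValue(tag) == Some("dataset-42"))
```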
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 760a9db8bead..3a471b683df9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -726,6 +726,13 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val FAIL_AMBIGUOUS_SELF_JOIN =
+    buildConf("spark.sql.analyzer.failAmbiguousSelfJoin")
+      .doc("When true, fail the Dataset query if it contains an ambiguous self-join.")
+      .internal()
+      .booleanConf
+      .createWithDefault(true)
+
   // Whether to retain group by columns or not in GroupedData.agg.
   val DATAFRAME_RETAIN_GROUP_COLUMNS = buildConf("spark.sql.retainGroupColumns")
     .internal()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index e37cf8a8e217..e94610e59c7a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -617,4 +617,57 @@ class TreeNodeSuite extends SparkFunSuite {
     val expected = Coalesce(Stream(Literal(1), Literal(3)))
     assert(result === expected)
   }
+
+  test("tags will be carried over after copy & transform") {
+    val tag = TreeNodeTag[String]("test")
+
+    withClue("makeCopy") {
+      val node = Dummy(None)
+      node.setTagValue(tag, "a")
+      val copied = node.makeCopy(Array(Some(Literal(1))))
+      assert(copied.getTagValue(tag) == Some("a"))
+    }
+
+    def checkTransform(
+        sameTypeTransform: Expression => Expression,
+        differentTypeTransform: Expression => Expression): Unit = {
+      val child = Dummy(None)
+      child.setTagValue(tag, "child")
+      val node = Dummy(Some(child))
+      node.setTagValue(tag, "parent")
+
+      val transformed = sameTypeTransform(node)
+      // Both the child and parent keep the tags
+      assert(transformed.getTagValue(tag) == Some("parent"))
+      assert(transformed.children.head.getTagValue(tag) == Some("child"))
+
+      val transformed2 = differentTypeTransform(node)
+      // Both the child and parent keep the tags, even if we transform the node to a new one of
+      // different type.
+      assert(transformed2.getTagValue(tag) == Some("parent"))
+      assert(transformed2.children.head.getTagValue(tag) == Some("child"))
+    }
+
+    withClue("transformDown") {
+      checkTransform(
+        sameTypeTransform = _ transformDown {
+          case Dummy(None) => Dummy(Some(Literal(1)))
+        },
+        differentTypeTransform = _ transformDown {
+          case Dummy(None) => Literal(1)
+
+        })
+    }
+
+    withClue("transformUp") {
+      checkTransform(
+        sameTypeTransform = _ transformUp {
+          case Dummy(None) => Dummy(Some(Literal(1)))
+        },
+        differentTypeTransform = _ transformUp {
+          case Dummy(None) => Literal(1)
+
+        })
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index a046127c3edb..bcb7cdac1327 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -48,6 +48,15 @@ private[sql] object Column {
       case expr => toPrettySQL(expr)
     }
   }
+
+  private[sql] def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = {
+    val metadataWithoutId = new MetadataBuilder()
+      .withMetadata(a.metadata)
+      .remove(Dataset.DATASET_ID_KEY)
+      .remove(Dataset.COL_POS_KEY)
+      .build()
+    a.withMetadata(metadataWithoutId)
+  }
 }
 
 /**
@@ -141,11 +150,15 @@ class Column(val expr: Expression) extends Logging {
   override def toString: String = toPrettySQL(expr)
 
   override def equals(that: Any): Boolean = that match {
-    case that: Column => that.expr.equals(this.expr)
+    case that: Column => that.normalizedExpr() == this.normalizedExpr()
     case _ => false
   }
 
-  override def hashCode: Int = this.expr.hashCode()
+  override def hashCode: Int = this.normalizedExpr().hashCode()
+
+  private def normalizedExpr(): Expression = expr transform {
+    case a: AttributeReference => Column.stripColumnReferenceMetadata(a)
+  }
 
   /** Creates a column based on the given expression. */
   private def withExpr(newExpr: Expression): Column = new Column(newExpr)
@@ -1023,7 +1036,7 @@ class Column(val expr: Expression) extends Logging {
    * @since 2.0.0
    */
   def name(alias: String): Column = withExpr {
-    expr match {
+    normalizedExpr() match {
       case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
       case other => Alias(other, alias)()
     }
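Why normalize in `equals`/`hashCode`? Once `Dataset#col` starts embedding a per-Dataset id into the reference metadata (next file), two `Column`s for the same underlying attribute could otherwise compare unequal. A sketch of the invariant this preserves, assuming a patched build and a `SparkSession` named `spark`:

```scala
import spark.implicits._

val df = spark.range(1).toDF("id")
val df2 = df.select($"id") // a new Dataset, but the same underlying attribute

// df("id") and df2("id") carry different hidden __dataset_id metadata, yet the
// normalization above keeps them equal as Columns.
assert(df("id") == df2("id"))
assert(df("id").hashCode == df2("id").hashCode)
```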
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c90b2e857e66..63e2433576b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -45,12 +45,14 @@ import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.execution.stat.StatFunctions
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
@@ -60,6 +62,11 @@ import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.Utils
 
 private[sql] object Dataset {
+  val curId = new java.util.concurrent.atomic.AtomicLong()
+  val DATASET_ID_KEY = "__dataset_id"
+  val COL_POS_KEY = "__col_position"
+  val DATASET_ID_TAG = TreeNodeTag[Long]("dataset_id")
+
   def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
     val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
     // Eagerly bind the encoder so we verify that the encoder matches the underlying
@@ -173,6 +180,9 @@ class Dataset[T] private[sql](
     encoder: Encoder[T])
   extends Serializable {
 
+  // A globally unique id of this Dataset.
+  private val id = Dataset.curId.getAndIncrement()
+
   queryExecution.assertAnalyzed()
 
   // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure
@@ -189,7 +199,7 @@ class Dataset[T] private[sql](
   @transient private[sql] val logicalPlan: LogicalPlan = {
     // For various commands (like DDL) and queries with side effects, we force query execution
     // to happen right away to let these side effects take place eagerly.
-    queryExecution.analyzed match {
+    val plan = queryExecution.analyzed match {
       case c: Command =>
         LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect()))
       case u @ Union(children) if children.forall(_.isInstanceOf[Command]) =>
@@ -197,6 +207,10 @@ class Dataset[T] private[sql](
       case _ =>
         queryExecution.analyzed
     }
+    if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN)) {
+      plan.setTagValue(Dataset.DATASET_ID_TAG, id)
+    }
+    plan
   }
 
   /**
@@ -1271,11 +1285,29 @@ class Dataset[T] private[sql](
     if (sqlContext.conf.supportQuotedRegexColumnName) {
       colRegex(colName)
     } else {
-      val expr = resolve(colName)
-      Column(expr)
+      Column(addDataFrameIdToCol(resolve(colName)))
     }
   }
 
+  // Attach the dataset id and column position to the column reference, so that we can detect
+  // ambiguous self-joins correctly. See the rule `DetectAmbiguousSelfJoin`.
+  // This must be called before we return a `Column` that contains an `AttributeReference`.
+  // Note that the metadata added here is only available in the analyzer, as the analyzer rule
+  // `DetectAmbiguousSelfJoin` will remove it.
+  private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = {
+    val newExpr = expr transform {
+      case a: AttributeReference
+          if sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN) =>
+        val metadata = new MetadataBuilder()
+          .withMetadata(a.metadata)
+          .putLong(Dataset.DATASET_ID_KEY, id)
+          .putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals))
+          .build()
+        a.withMetadata(metadata)
+    }
+    newExpr.asInstanceOf[NamedExpression]
+  }
+
   /**
    * Selects column based on the column name specified as a regex and returns it as [[Column]].
    * @group untypedrel
@@ -1289,7 +1321,7 @@ class Dataset[T] private[sql](
       case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) =>
         Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive))
       case _ =>
-        Column(resolve(colName))
+        Column(addDataFrameIdToCol(resolve(colName)))
     }
   }
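To see `addDataFrameIdToCol` at work, one can peek at the expression behind a freshly returned `Column`; the two keys are present until `DetectAmbiguousSelfJoin` (next file) strips them during analysis. Illustrative only, and it relies on internal APIs:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference

val df = spark.range(3).toDF("value")
val ref = df("value").expr.asInstanceOf[AttributeReference]
// Prints something like: {"__dataset_id":7,"__col_position":0}
// (the exact id depends on how many Datasets were created before this one).
println(ref.metadata.json)
```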
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala
new file mode 100644
index 000000000000..a0eeadd69bef
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.analysis
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.{AnalysisException, Column, Dataset}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Equality, Expression, ExprId}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Detects ambiguous self-joins, so that we can fail the query instead of returning confusing
+ * results.
+ *
+ * A Dataset column reference is simply an [[AttributeReference]] that is returned by
+ * `Dataset#col`. Most of the time we don't need to do anything special, as an
+ * [[AttributeReference]] can point to the column precisely. However, in the case of a self-join,
+ * the analyzer generates [[AttributeReference]]s with new expr IDs for the right side plan of
+ * the join. If a Dataset column reference points to a column in the right side plan of a
+ * self-join, users will get unexpected results, because the column reference can't match the
+ * newly generated [[AttributeReference]]s.
+ *
+ * Note that this rule removes all the Dataset id related metadata from `AttributeReference`s, so
+ * that they don't exist after the analyzer.
+ */
+class DetectAmbiguousSelfJoin(conf: SQLConf) extends Rule[LogicalPlan] {
+
+  // A Dataset column reference is an `AttributeReference` with 2 special metadata entries.
+  private def isColumnReference(a: AttributeReference): Boolean = {
+    a.metadata.contains(Dataset.DATASET_ID_KEY) && a.metadata.contains(Dataset.COL_POS_KEY)
+  }
+
+  private case class ColumnReference(datasetId: Long, colPos: Int, exprId: ExprId)
+
+  private def toColumnReference(a: AttributeReference): ColumnReference = {
+    ColumnReference(
+      a.metadata.getLong(Dataset.DATASET_ID_KEY),
+      a.metadata.getLong(Dataset.COL_POS_KEY).toInt,
+      a.exprId)
+  }
+
+  object LogicalPlanWithDatasetId {
+    def unapply(p: LogicalPlan): Option[(LogicalPlan, Long)] = {
+      p.getTagValue(Dataset.DATASET_ID_TAG).map(id => p -> id)
+    }
+  }
+
+  object AttrWithCast {
+    def unapply(expr: Expression): Option[AttributeReference] = expr match {
+      case Cast(child, _, _) => unapply(child)
+      case a: AttributeReference => Some(a)
+      case _ => None
+    }
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN)) return plan
+
+    // We always remove the special metadata from `AttributeReference` at the end of this rule, so
+    // Dataset column references only exist in the root node via Dataset transformations like
+    // `Dataset#select`.
+    val colRefAttrs = plan.expressions.flatMap(_.collect {
+      case a: AttributeReference if isColumnReference(a) => a
+    })
+
+    if (colRefAttrs.nonEmpty) {
+      val colRefs = colRefAttrs.map(toColumnReference).distinct
+      val ambiguousColRefs = mutable.HashSet.empty[ColumnReference]
+      val dsIdSet = colRefs.map(_.datasetId).toSet
+
+      plan.foreach {
+        case LogicalPlanWithDatasetId(p, id) if dsIdSet.contains(id) =>
+          colRefs.foreach { ref =>
+            if (id == ref.datasetId) {
+              if (ref.colPos < 0 || ref.colPos >= p.output.length) {
+                throw new IllegalStateException("[BUG] Hit an invalid Dataset column reference: " +
+                  s"$ref. Please open a JIRA ticket to report it.")
+              } else {
+                // When a self-join happens, the analyzer asks the right side plan to generate
+                // attributes with new exprIds. If a plan of a Dataset outputs an attribute that is
+                // referred to by a column reference with a different exprId, then the column
+                // reference is ambiguous: it refers to a column regenerated by the self-join.
+                val actualAttr = p.output(ref.colPos).asInstanceOf[AttributeReference]
+                if (actualAttr.exprId != ref.exprId) {
+                  ambiguousColRefs += ref
+                }
+              }
+            }
+          }
+
+        case _ =>
+      }
+
+      val ambiguousAttrs: Seq[AttributeReference] = plan match {
+        case Join(
+            LogicalPlanWithDatasetId(_, leftId),
+            LogicalPlanWithDatasetId(_, rightId),
+            _, condition) =>
+          // If we are dealing with the root join node, we need to take care of SPARK-6231:
+          // 1. We can disambiguate `df("col") === df("col")` in the join condition.
+          // 2. There is no ambiguity in a direct self join like
+          //    `df.join(df, df("col") === 1)`, because it doesn't matter which side the
+          //    column comes from.
+          def getAmbiguousAttrs(expr: Expression): Seq[AttributeReference] = expr match {
+            case Equality(AttrWithCast(a), AttrWithCast(b)) if a.sameRef(b) =>
+              Nil
+            case Equality(AttrWithCast(a), b) if leftId == rightId && b.foldable =>
+              Nil
+            case Equality(a, AttrWithCast(b)) if leftId == rightId && a.foldable =>
+              Nil
+            case a: AttributeReference =>
+              if (isColumnReference(a)) {
+                val colRef = toColumnReference(a)
+                if (ambiguousColRefs.contains(colRef)) Seq(a) else Nil
+              } else {
+                Nil
+              }
+            case _ => expr.children.flatMap(getAmbiguousAttrs)
+          }
+          condition.toSeq.flatMap(getAmbiguousAttrs)
+
+        case _ => ambiguousColRefs.toSeq.map { ref =>
+          colRefAttrs.find(attr => toColumnReference(attr) == ref).get
+        }
+      }
+
+      if (ambiguousAttrs.nonEmpty) {
+        throw new AnalysisException(s"Column ${ambiguousAttrs.mkString(", ")} are ambiguous. " +
+          "It's probably because you joined several Datasets together, and some of these " +
+          "Datasets are the same. This column points to one of the Datasets but Spark is unable " +
+          "to figure out which one. Please alias the Datasets with different names via " +
+          "`Dataset.as` before joining them, and specify the column using qualified name, e.g. " +
+          """`df.as("a").join(df.as("b"), $"a.id" > $"b.id")`. You can also set """ +
+          s"${SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key} to false to disable this check.")
+      }
+    }
+
+    plan.transformExpressions {
+      case a: AttributeReference if isColumnReference(a) =>
+        // Remove the special metadata from this `AttributeReference`, as the detection is done.
+        Column.stripColumnReferenceMetadata(a)
+    }
+  }
+}
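The SPARK-6231 carve-outs in `getAmbiguousAttrs` are what keep the long-standing self-join idioms working: only references the rule provably cannot resolve are rejected. A sketch of where the boundary falls, assuming a patched build with `spark.sql.analyzer.failAmbiguousSelfJoin` at its default of `true` (these cases mirror the test suites below):

```scala
import spark.implicits._

val df = Seq((1, "1"), (2, "2")).toDF("key", "value")

// Still fine: both sides of the equality are the same column reference, which
// the rule auto-resolves as in SPARK-6231.
df.join(df, df("key") === df("key")).show()

// Still fine: a foldable comparison riding along with an auto-resolved equality.
df.join(df, df("key") === df("key") && df("value") === 1).show()

// Now rejected: df2 derives from df, so df("key") and df2("key") are the same
// attribute and the rule cannot tell which side is meant.
val df2 = df.filter($"value" === "2")
// df.join(df2, df("key") > df2("key")) // throws AnalysisException: ... ambiguous ...
```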
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 9c1a15c46acd..b83d97a32316 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.util.ExecutionListenerManager
@@ -161,7 +162,8 @@ abstract class BaseSessionStateBuilder(
       customResolutionRules
 
     override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
-      PreprocessTableCreation(session) +:
+      new DetectAmbiguousSelfJoin(conf) +:
+      PreprocessTableCreation(session) +:
       PreprocessTableInsertion(conf) +:
       DataSourceAnalysis(conf) +:
       customPostHocResolutionRules
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0106c44b7db..0b53b49a6b01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -167,6 +167,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       Row(null, null, 1, 1, 3) :: Nil
     )
 
+    // use column reference in `grouping_id` instead of column name
+    checkAnswer(
+      courseSales.cube("course", "year")
+        .agg(grouping_id(courseSales("course"), courseSales("year"))),
+      Row("Java", 2012, 0) ::
+        Row("Java", 2013, 0) ::
+        Row("Java", null, 1) ::
+        Row("dotNET", 2012, 0) ::
+        Row("dotNET", 2013, 0) ::
+        Row("dotNET", null, 1) ::
+        Row(null, 2012, 2) ::
+        Row(null, 2013, 2) ::
+        Row(null, null, 3) :: Nil
+    )
+
     intercept[AnalysisException] {
       courseSales.groupBy().agg(grouping("course")).explain()
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index e6b30f9956da..61bf91a27253 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -86,25 +86,6 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       Row(1, "1", 2) :: Nil)
   }
 
-  test("join - join using self join") {
-    val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
-
-    // self join
-    checkAnswer(
-      df.join(df, "int"),
-      Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil)
-  }
-
-  test("join - self join") {
-    val df1 = testData.select(testData("key")).as('df1)
-    val df2 = testData.select(testData("key")).as('df2)
-
-    checkAnswer(
-      df1.join(df2, $"df1.key" === $"df2.key"),
-      sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
-        .collect().toSeq)
-  }
-
   test("join - cross join") {
     val df1 = Seq((1, "1"), (3, "3")).toDF("int", "str")
     val df2 = Seq((2, "2"), (4, "4")).toDF("int", "str")
@@ -120,38 +101,6 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       Row(4, "4", 1, "1") :: Row(4, "4", 3, "3") :: Nil)
   }
 
-  test("join - using aliases after self join") {
-    val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
-    checkAnswer(
-      df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(),
-      Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
-
-    checkAnswer(
-      df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(),
-      Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
-  }
-
-  test("[SPARK-6231] join - self join auto resolve ambiguity") {
-    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
-    checkAnswer(
-      df.join(df, df("key") === df("key")),
-      Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil)
-
-    checkAnswer(
-      df.join(df.filter($"value" === "2"), df("key") === df("key")),
-      Row(2, "2", 2, "2") :: Nil)
-
-    checkAnswer(
-      df.join(df, df("key") === df("key") && df("value") === 1),
-      Row(1, "1", 1, "1") :: Nil)
-
-    val left = df.groupBy("key").agg(count("*"))
-    val right = df.groupBy("key").agg(sum("key"))
-    checkAnswer(
-      left.join(right, left("key") === right("key")),
-      Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)
-  }
-
   test("broadcast join hint using broadcast function") {
     val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
     val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
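The new `DataFrameAggregateSuite` case above is a consequence of the `Analyzer` change at the top of this diff: `grouping_id` arguments are now compared by their canonicalized form, so column references (which now carry extra metadata) still match the cube's grouping expressions. A sketch, assuming a patched build:

```scala
import org.apache.spark.sql.functions.grouping_id
import spark.implicits._

val sales = Seq(("Java", 2012), ("dotNET", 2013)).toDF("course", "year")

// Passing column references instead of names to grouping_id used to fail the
// strict e.groupByExprs == groupByExprs check once the references carried
// metadata; comparing canonicalized expressions makes it work.
sales.cube("course", "year")
  .agg(grouping_id(sales("course"), sales("year")))
  .show()
```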
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
new file mode 100644
index 000000000000..92f1e4306c5b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("join - join using self join") {
+    val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
+
+    // self join
+    checkAnswer(
+      df.join(df, "int"),
+      Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil)
+  }
+
+  test("join - self join") {
+    val df1 = testData.select(testData("key")).as('df1)
+    val df2 = testData.select(testData("key")).as('df2)
+
+    checkAnswer(
+      df1.join(df2, $"df1.key" === $"df2.key"),
+      sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
+        .collect().toSeq)
+  }
+
+  test("join - self join auto resolve ambiguity with case insensitivity") {
+    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
+    checkAnswer(
+      df.join(df, df("key") === df("Key")),
+      Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil)
+
+    checkAnswer(
+      df.join(df.filter($"value" === "2"), df("key") === df("Key")),
+      Row(2, "2", 2, "2") :: Nil)
+  }
+
+  test("join - using aliases after self join") {
+    val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
+    checkAnswer(
+      df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(),
+      Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+
+    checkAnswer(
+      df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(),
+      Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+  }
+
+  test("[SPARK-6231] join - self join auto resolve ambiguity") {
+    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
+    checkAnswer(
+      df.join(df, df("key") === df("key")),
+      Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil)
+
+    checkAnswer(
+      df.join(df.filter($"value" === "2"), df("key") === df("key")),
+      Row(2, "2", 2, "2") :: Nil)
+
+    checkAnswer(
+      df.join(df, df("key") === df("key") && df("value") === 1),
+      Row(1, "1", 1, "1") :: Nil)
+
+    val left = df.groupBy("key").agg(count("*"))
+    val right = df.groupBy("key").agg(sum("key"))
+    checkAnswer(
+      left.join(right, left("key") === right("key")),
+      Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)
+  }
+
+  private def assertAmbiguousSelfJoin(df: => DataFrame): Unit = {
+    val e = intercept[AnalysisException](df)
+    assert(e.message.contains("ambiguous"))
+  }
+
+  test("SPARK-28344: fail ambiguous self join - column ref in join condition") {
+    val df1 = spark.range(3)
+    val df2 = df1.filter($"id" > 0)
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "false",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      // `df1("id") > df2("id")` is always false.
+      checkAnswer(df1.join(df2, df1("id") > df2("id")), Nil)
+
+      // Aliasing the DataFrames and using qualified column names fixes the ambiguous self-join.
+      val aliasedDf1 = df1.alias("left")
+      val aliasedDf2 = df2.as("right")
+      checkAnswer(
+        aliasedDf1.join(aliasedDf2, $"left.id" > $"right.id"),
+        Seq(Row(2, 1)))
+    }
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      assertAmbiguousSelfJoin(df1.join(df2, df1("id") > df2("id")))
+    }
+  }
+
+  test("SPARK-28344: fail ambiguous self join - Dataset.colRegex as column ref") {
+    val df1 = spark.range(3)
+    val df2 = df1.filter($"id" > 0)
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      assertAmbiguousSelfJoin(df1.join(df2, df1.colRegex("id") > df2.colRegex("id")))
+    }
+  }
+
+  test("SPARK-28344: fail ambiguous self join - Dataset.col with nested field") {
+    val df1 = spark.read.json(Seq("""{"a": {"b": 1, "c": 1}}""").toDS())
+    val df2 = df1.filter($"a.b" > 0)
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      assertAmbiguousSelfJoin(df1.join(df2, df1("a.b") > df2("a.c")))
+    }
+  }
+
+  test("SPARK-28344: fail ambiguous self join - column ref in Project") {
+    val df1 = spark.range(3)
+    val df2 = df1.filter($"id" > 0)
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "false",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      // `df2("id")` actually points to the column of `df1`.
+      checkAnswer(df1.join(df2).select(df2("id")), Seq(0, 0, 1, 1, 2, 2).map(Row(_)))
+
+      // Aliasing the DataFrames and using qualified column names fixes the ambiguous self-join.
+      val aliasedDf1 = df1.alias("left")
+      val aliasedDf2 = df2.as("right")
+      checkAnswer(
+        aliasedDf1.join(aliasedDf2).select($"right.id"),
+        Seq(1, 1, 1, 2, 2, 2).map(Row(_)))
+    }
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      assertAmbiguousSelfJoin(df1.join(df2).select(df2("id")))
+    }
+  }
+
+  test("SPARK-28344: fail ambiguous self join - join three tables") {
+    val df1 = spark.range(3)
+    val df2 = df1.filter($"id" > 0)
+    val df3 = df1.filter($"id" <= 2)
+    val df4 = spark.range(1)
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "false",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      // `df2("id") < df3("id")` is always false
+      checkAnswer(df1.join(df2).join(df3, df2("id") < df3("id")), Nil)
+      // `df2("id")` actually points to the column of `df1`.
+      checkAnswer(
+        df1.join(df4).join(df2).select(df2("id")),
+        Seq(0, 0, 1, 1, 2, 2).map(Row(_)))
+      // `df4("id")` is not ambiguous.
+      checkAnswer(
+        df1.join(df4).join(df2).select(df4("id")),
+        Seq(0, 0, 0, 0, 0, 0).map(Row(_)))
+
+      // Aliasing the DataFrames and using qualified column names fixes the ambiguous self-join.
+      val aliasedDf1 = df1.alias("x")
+      val aliasedDf2 = df2.as("y")
+      val aliasedDf3 = df3.as("z")
+      checkAnswer(
+        aliasedDf1.join(aliasedDf2).join(aliasedDf3, $"y.id" < $"z.id"),
+        Seq(Row(0, 1, 2), Row(1, 1, 2), Row(2, 1, 2)))
+      checkAnswer(
+        aliasedDf1.join(df4).join(aliasedDf2).select($"y.id"),
+        Seq(1, 1, 1, 2, 2, 2).map(Row(_)))
+    }
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      assertAmbiguousSelfJoin(df1.join(df2).join(df3, df2("id") < df3("id")))
+      assertAmbiguousSelfJoin(df1.join(df4).join(df2).select(df2("id")))
+    }
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 2882672f327c..9d43ee1c12e3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlanner
+import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.hive.client.HiveClient
 import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}
@@ -74,7 +75,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState])
       customResolutionRules
 
     override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
-      new DetermineTableStats(session) +:
+      new DetectAmbiguousSelfJoin(conf) +:
+      new DetermineTableStats(session) +:
       RelationConversions(conf, catalog) +:
       PreprocessTableCreation(session) +:
       PreprocessTableInsertion(conf) +:
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala
deleted file mode 100644
index cdc259d75b13..000000000000
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import org.apache.spark.sql.{QueryTest, Row}
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-
-class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton {
-  import spark.implicits._
-
-  // We should move this into SQL package if we make case sensitivity configurable in SQL.
-  test("join - self join auto resolve ambiguity with case insensitivity") {
-    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
-    checkAnswer(
-      df.join(df, df("key") === df("Key")),
-      Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil)
-
-    checkAnswer(
-      df.join(df.filter($"value" === "2"), df("key") === df("Key")),
-      Row(2, "2", 2, "2") :: Nil)
-  }
-
-}
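For completeness, the escape hatch end to end: legacy jobs that depended on the old silent behavior can disable the check, as the migration note at the top of this diff describes. A sketch, assuming an active `SparkSession` named `spark`:

```scala
import spark.implicits._

// Restore the pre-2.4.5 behavior (and allow the resulting cartesian plan, as
// the new test suite does via spark.sql.crossJoin.enabled).
spark.conf.set("spark.sql.analyzer.failAmbiguousSelfJoin", "false")
spark.conf.set("spark.sql.crossJoin.enabled", "true")

val df1 = spark.range(3).toDF("a")
val df2 = df1.filter($"a" > 0)
// Analyzes again, but silently returns an empty result, which is exactly the
// confusion the new check is designed to surface.
df1.join(df2, df1("a") > df2("a")).show()
```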