diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02d83e7e8cb6..b4b1bf89a96d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -345,7 +345,8 @@ class Analyzer( gid: Expression): Expression = { expr transform { case e: GroupingID => - if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) { + if (e.groupByExprs.isEmpty || + e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) { Alias(gid, toPrettySQL(e))() } else { throw new AnalysisException( @@ -952,6 +953,8 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) + // Intersect/Except will be rewritten to Join at the beginning of the optimizer. Here we need to + // deduplicate the right side plan, so that we won't produce an invalid self-join later. case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) case e @ Except(left, right, _) if !e.duplicateResolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 2b98132f188f..2080a1abfe66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -951,13 +951,24 @@ case class SubqueryAlias( def alias: String = name.identifier override def output: Seq[Attribute] = { - val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias)) - child.output.map(_.withQualifier(qualifierList)) + if (isHiddenAlias) { + child.output + } else { + val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias)) + child.output.map(_.withQualifier(qualifierList)) + } } + override def doCanonicalize(): LogicalPlan = child.canonicalized + + def isHiddenAlias: Boolean = { + name.database.isEmpty && name.identifier.startsWith(SubqueryAlias.HIDDEN_ALIAS_PREFIX) + } } object SubqueryAlias { + val HIDDEN_ALIAS_PREFIX = "__hidden_alias" + def apply( identifier: String, child: LogicalPlan): SubqueryAlias = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b4c68a701411..cc50d6dec403 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -759,6 +759,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val RESOLVE_DATASET_COLUMN_REFERENCE = + buildConf("spark.sql.analyzer.resolveDatasetColumnReference") + .doc("When true, resolve Dataset column references in case of self-join.") + .internal() + .booleanConf + .createWithDefault(true) + // Whether to retain group by columns or not in GroupedData.agg. 
val DATAFRAME_RETAIN_GROUP_COLUMNS = buildConf("spark.sql.retainGroupColumns") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 5a408b29f933..a9d6e49d80a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -48,6 +48,15 @@ private[sql] object Column { case expr => toPrettySQL(expr) } } + + private[sql] def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = { + val metadataWithoutId = new MetadataBuilder() + .withMetadata(a.metadata) + .remove(Dataset.ID_PREFIX) + .remove(Dataset.COL_POS_PREFIX) + .build() + a.withMetadata(metadataWithoutId) + } } /** @@ -144,11 +153,16 @@ class Column(val expr: Expression) extends Logging { override def toString: String = toPrettySQL(expr) override def equals(that: Any): Boolean = that match { - case that: Column => that.expr.equals(this.expr) + case that: Column => that.normalizedExpr().equals(this.normalizedExpr()) case _ => false } - override def hashCode: Int = this.expr.hashCode() + override def hashCode: Int = this.normalizedExpr().hashCode() + + private def normalizedExpr(): Expression = expr match { + case a: AttributeReference => Column.stripColumnReferenceMetadata(a) + case _ => expr + } /** Creates a column based on the given expression. */ private def withExpr(newExpr: Expression): Column = new Column(newExpr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5d6e5306f174..f450018c1dd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -62,6 +63,11 @@ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils private[sql] object Dataset { + val curId = new java.util.concurrent.atomic.AtomicLong() + + val ID_PREFIX = "__dataset_id" + val COL_POS_PREFIX = "__col_position" + def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) // Eagerly bind the encoder so we verify that the encoder matches the underlying @@ -183,6 +189,9 @@ class Dataset[T] private[sql]( @DeveloperApi @Unstable @transient val encoder: Encoder[T]) extends Serializable { + // A globally unique id for this Dataset. 
+ private val id = Dataset.curId.getAndIncrement() + queryExecution.assertAnalyzed() // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure @@ -873,7 +882,25 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) + val (joinLeft, joinRight) = prepareJoinPlan(this, right) + Join(joinLeft, joinRight, joinType = Inner, None, JoinHint.NONE) + } + + // Called by `Dataset#join`, to attach the Dataset id to the logical plan, so that we + // can resolve column reference correctly later. See `ResolveDatasetColumnReference`. + private def createPlanWithDatasetId(): LogicalPlan = { + // The alias should start with `SubqueryAlias.HIDDEN_ALIAS_PREFIX`, so that `SubqueryAlias` can + // recognize it and keep the output qualifiers unchanged. + SubqueryAlias(s"${SubqueryAlias.HIDDEN_ALIAS_PREFIX}${Dataset.ID_PREFIX}_$id", logicalPlan) + } + + private def prepareJoinPlan(left: Dataset[_], right: Dataset[_]): (LogicalPlan, LogicalPlan) = { + if (!sparkSession.sessionState.conf.getConf(SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE)) { + // If the config is disabled, do nothing. + (left.logicalPlan, right.logicalPlan) + } else { + (left.createPlanWithDatasetId(), right.createPlanWithDatasetId()) + } } /** @@ -949,10 +976,11 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { + val (joinLeft, joinRight) = prepareJoinPlan(this, right) // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE)) + Join(joinLeft, joinRight, joinType = JoinType(joinType), None, JoinHint.NONE)) .analyzed.asInstanceOf[Join] withPlan { @@ -1014,8 +1042,9 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. + val (joinLeft, joinRight) = prepareJoinPlan(this, right) val plan = withPlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr), JoinHint.NONE)) + Join(joinLeft, joinRight, JoinType(joinType), Some(joinExprs.expr), JoinHint.NONE)) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -1024,9 +1053,7 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed - val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed - if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { + if (this.logicalPlan.outputSet.intersect(right.logicalPlan.outputSet).isEmpty) { return withPlan(plan) } @@ -1289,10 +1316,24 @@ class Dataset[T] private[sql]( colRegex(colName) } else { val expr = resolve(colName) - Column(expr) + Column(addDataFrameIdToCol(expr)) } } + // Attach the dataset id and column position to the column reference, so that we can resolve it + // correctly in case of self-join. See `ResolveDatasetColumnReference`. 
+ private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = expr match { + case a: AttributeReference + if sparkSession.sessionState.conf.getConf(SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE) => + val metadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(Dataset.ID_PREFIX, id) + .putLong(Dataset.COL_POS_PREFIX, logicalPlan.output.indexWhere(a.semanticEquals)) + .build() + a.withMetadata(metadata) + case _ => expr + } + /** * Selects column based on the column name specified as a regex and returns it as [[Column]]. * @group untypedrel @@ -2297,11 +2338,16 @@ class Dataset[T] private[sql]( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } - val attrs = this.logicalPlan.output - val colsAfterDrop = attrs.filter { attr => - attr != expression - }.map(attr => Column(attr)) - select(colsAfterDrop : _*) + expression match { + case a: Attribute => + val attrs = this.logicalPlan.output + val colsAfterDrop = attrs.filter { attr => + !attr.semanticEquals(a) + }.map(attr => Column(attr)) + select(colsAfterDrop : _*) + + case _ => toDF() + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/ResolveDatasetColumnReference.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/ResolveDatasetColumnReference.scala new file mode 100644 index 000000000000..4bf1b9a1ba3f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/ResolveDatasetColumnReference.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.analysis + +import scala.collection.mutable +import scala.util.Try + +import org.apache.spark.sql.{AnalysisException, Column, Dataset} +import org.apache.spark.sql.catalyst.AliasIdentifier +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Equality, EqualNullSafe, EqualTo} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + +/** + * Resolves Dataset column references by traversing the query plan and finding the plan subtree + * of the Dataset that the column reference belongs to. + * + * A Dataset column reference is simply an [[AttributeReference]] that is returned by `Dataset#col`. + * Most of the time we don't need to do anything special, as [[AttributeReference]] can point to + * the column precisely. However, in the case of a self-join, the analyzer generates + * [[AttributeReference]]s with new expr IDs for the right side plan of the join. 
If the Dataset + * column reference points to a column in the right side plan of a self-join, we need to replace it + * with the corresponding newly generated [[AttributeReference]]. + */ +class ResolveDatasetColumnReference(conf: SQLConf) extends Rule[LogicalPlan] { + + // A Dataset column reference is an `AttributeReference` with two special metadata entries. + private def isColumnReference(a: AttributeReference): Boolean = { + a.metadata.contains(Dataset.ID_PREFIX) && a.metadata.contains(Dataset.COL_POS_PREFIX) + } + + private case class ColumnReference(datasetId: Long, colPos: Int) + + private def toColumnReference(a: AttributeReference): ColumnReference = { + ColumnReference( + a.metadata.getLong(Dataset.ID_PREFIX), + a.metadata.getLong(Dataset.COL_POS_PREFIX).toInt) + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE)) return plan + + // We always remove the special metadata from `AttributeReference` at the end of this rule, so + // Dataset column references only exist in the root node, via Dataset transformations like + // `Dataset#select`. + val colRefAttrs = plan.expressions.flatMap(_.collect { + case a: AttributeReference if isColumnReference(a) => a + }) + + if (colRefAttrs.isEmpty) { + plan + } else { + val colRefs = colRefAttrs.map(toColumnReference).distinct + // Keeps the mapping between the column reference and the actual attribute it points to. This + // will be used to replace the column references with actual attributes later. + val colRefToActualAttr = new mutable.HashMap[ColumnReference, AttributeReference]() + // Keeps the column references that point to more than one actual attribute. + val ambiguousColRefs = new mutable.HashMap[ColumnReference, Seq[AttributeReference]]() + // We only care about `SubqueryAlias` referring to Datasets which produce the column + // references that we want to resolve here. + val dsIdSet = colRefs.map(_.datasetId).toSet + // If a column reference points to an attribute that is not present in the plan's inputSet, we + // should ignore it as it's invalid. + val inputSet = plan.inputSet + + plan.foreach { + // We only add the special `SubqueryAlias` to attach the dataset id for self-join. After + // self-join resolving, the child of `SubqueryAlias` should have generated new + // `AttributeReference`s, and we need to resolve the column references against them. + case SubqueryAlias(DatasetIdAlias(id), child) if dsIdSet.contains(id) => + colRefs.foreach { case ref => + if (id == ref.datasetId) { + if (ref.colPos < 0 || ref.colPos >= child.output.length) { + throw new IllegalStateException("[BUG] Hit an invalid Dataset column reference: " + + s"$ref. Please open a JIRA ticket to report it.") + } else { + val actualAttr = child.output(ref.colPos).asInstanceOf[AttributeReference] + if (inputSet.contains(actualAttr)) { + // Record the ambiguous column references. We will deal with them later. 
+ if (ambiguousColRefs.contains(ref)) { + assert(!colRefToActualAttr.contains(ref)) + ambiguousColRefs(ref) = ambiguousColRefs(ref) :+ actualAttr + } else if (colRefToActualAttr.contains(ref)) { + ambiguousColRefs(ref) = Seq(colRefToActualAttr.remove(ref).get, actualAttr) + } else { + colRefToActualAttr(ref) = actualAttr + } + } + + } + } + } + + case _ => + } + + val deAmbiguousColsRefs = new mutable.HashSet[ColumnReference]() + val newPlan = plan.transformExpressions { + case e @ Equality(a: AttributeReference, b: AttributeReference) + if isColumnReference(a) && isColumnReference(b) && a.sameRef(b) => + val colRefA = toColumnReference(a) + val colRefB = toColumnReference(b) + val maybeActualAttrs = ambiguousColRefs.get(colRefA) + if (colRefA == colRefB && maybeActualAttrs.exists(_.length == 2)) { + deAmbiguousColsRefs += colRefA + if (e.isInstanceOf[EqualTo]) { + EqualTo(maybeActualAttrs.get.head, maybeActualAttrs.get.last) + } else { + EqualNullSafe(maybeActualAttrs.get.head, maybeActualAttrs.get.last) + } + } else { + e + } + + case a: AttributeReference if isColumnReference(a) => + val actualAttr = colRefToActualAttr.getOrElse(toColumnReference(a), a) + // Remove the special metadata from this `AttributeReference`, as the column reference + // resolving is done. + Column.stripColumnReferenceMetadata(actualAttr) + } + + ambiguousColRefs.filterKeys(!deAmbiguousColsRefs.contains(_)).foreach { case (ref, _) => + val originalAttr = colRefAttrs.find(attr => toColumnReference(attr) == ref).get + throw new AnalysisException(s"Column $originalAttr is ambiguous. It's probably " + + "because you joined several Datasets together, and some of these Datasets are the " + + "same. This column points to one of the Datasets but Spark is unable to figure out " + + "which Dataset. Please alias the Datasets with different names via `Dataset.as` " + + "before joining them, and specify the column using a qualified name, e.g. 
" + + """`df.as("a").join(df.as("b"), $"a.id" > $"b.id")`.""") + } + + newPlan + } + } + + object DatasetIdAlias { + def unapply(alias: AliasIdentifier): Option[Long] = { + val expectedPrefix = SubqueryAlias.HIDDEN_ALIAS_PREFIX + Dataset.ID_PREFIX + if (alias.database.isEmpty && alias.identifier.startsWith(expectedPrefix)) { + Try(alias.identifier.drop(expectedPrefix.length + 1).toLong).toOption + } else { + None + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index b2d065274b15..dbb9f26e0b1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.analysis.ResolveDatasetColumnReference import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck} import org.apache.spark.sql.streaming.StreamingQueryManager @@ -173,7 +174,8 @@ abstract class BaseSessionStateBuilder( customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = - PreprocessTableCreation(session) +: + new ResolveDatasetColumnReference(conf) +: + PreprocessTableCreation(session) +: PreprocessTableInsertion(conf) +: DataSourceAnalysis(conf) +: customPostHocResolutionRules diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d89ecc22a7c0..b0dc888f89ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -167,6 +167,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, null, 1, 1, 3) :: Nil ) + // use column reference in `grouping_id` instead of column name + checkAnswer( + courseSales.cube("course", "year") + .agg(grouping_id(courseSales("course"), courseSales("year"))), + Row("Java", 2012, 0) :: + Row("Java", 2013, 0) :: + Row("Java", null, 1) :: + Row("dotNET", 2012, 0) :: + Row("dotNET", 2013, 0) :: + Row("dotNET", null, 1) :: + Row(null, 2012, 2) :: + Row(null, 2013, 2) :: + Row(null, null, 3) :: Nil + ) + intercept[AnalysisException] { courseSales.groupBy().agg(grouping("course")).explain() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index ba120dca712d..f4a70fe4bd98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -117,6 +117,66 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { .collect().toSeq) } + test("join - self join auto resolve ambiguity with case insensitivity") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.join(df, df("key") === df("Key")), + Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df.filter($"value" === "2"), df("key") === 
df("Key")), + Row(2, "2", 2, "2") :: Nil) + } + + test("SPARK-27547: join - self join without manual alias") { + val df1 = spark.range(3) + val df2 = df1.filter($"id" > 0) + + withSQLConf( + SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE.key -> "false", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // `df1("id") > df2("id")` is always false. + checkAnswer(df1.join(df2, df1("id") > df2("id")), Nil) + + // `df2("id")` actually points to the column of `df1`. + checkAnswer(df1.join(df2).select(df2("id")), Seq(0, 0, 1, 1, 2, 2).map(Row(_))) + + val df3 = df1.filter($"id" <= 2) + // `df2("id") < df3("id")` is always false + checkAnswer(df1.join(df2).join(df3, df2("id") < df3("id")), Nil) + } + + withSQLConf( + SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + + checkAnswer(df1.join(df2, df1("id") > df2("id")), Row(2, 1)) + + checkAnswer(df1.join(df2).select(df2("id")), Seq(1, 2, 1, 2, 1, 2).map(Row(_))) + + val df3 = df1.filter($"id" <= 2) + checkAnswer( + df1.join(df2).join(df3, df2("id") < df3("id")), + Row(0, 1, 2) :: Row(1, 1, 2) :: Row(2, 1, 2) :: Nil) + + checkAnswer( + df3.join(df1.join(df2), df2("id") < df3("id")), + Row(2, 0, 1) :: Row(2, 1, 1) :: Row(2, 2, 1) :: Nil) + + // `df1("id")` is ambiguous here. + intercept[AnalysisException](df1.join(df2).join(df1, df1("id") > df2("id"))) + + // `df2("id")` is not ambiguous. + checkAnswer( + df1.join(df2).join(df1, df2("id") === 2), + Seq(0, 1, 2).flatMap { i => + Seq(0, 1, 2).map { j => + Row(i, 2, j) + } + }) + } + } + test("join - cross join") { val df1 = Seq((1, "1"), (3, "3")).toDF("int", "str") val df2 = Seq((2, "2"), (4, "4")).toDF("int", "str") @@ -153,9 +213,14 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { df.join(df.filter($"value" === "2"), df("key") === df("key")), Row(2, "2", 2, "2") :: Nil) - checkAnswer( - df.join(df, df("key") === df("key") && df("value") === 1), - Row(1, "1", 1, "1") :: Nil) + withSQLConf(SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE.key -> "false") { + // `df("value")` is ambiguous. But under this case, it's OK because the "value" columns from + // the two DataFrames always equal. + // TODO: support it in `ResolveDatasetColumnReference`. 
+ checkAnswer( + df.join(df, df("key") === df("key") && df("value") === 1), + Row(1, "1", 1, "1") :: Nil) + } val left = df.groupBy("key").agg(count("*")) val right = df.groupBy("key").agg(sum("key")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 0e7df8e92197..4e8e716219d0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner +import org.apache.spark.sql.execution.analysis.ResolveDatasetColumnReference import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck} import org.apache.spark.sql.hive.client.HiveClient @@ -77,7 +78,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = - new DetermineTableStats(session) +: + new ResolveDatasetColumnReference(conf) +: + new DetermineTableStats(session) +: RelationConversions(conf, catalog) +: PreprocessTableCreation(session) +: PreprocessTableInsertion(conf) +: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala deleted file mode 100644 index cdc259d75b13..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHiveSingleton - -class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { - import spark.implicits._ - - // We should move this into SQL package if we make case sensitivity configurable in SQL. - test("join - self join auto resolve ambiguity with case insensitivity") { - val df = Seq((1, "1"), (2, "2")).toDF("key", "value") - checkAnswer( - df.join(df, df("key") === df("Key")), - Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) - - checkAnswer( - df.join(df.filter($"value" === "2"), df("key") === df("Key")), - Row(2, "2", 2, "2") :: Nil) - } - -}
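
For reference, below is a minimal, self-contained sketch (not part of the patch) of the user-facing behavior that the new ResolveDatasetColumnReference rule and the spark.sql.analyzer.resolveDatasetColumnReference flag target, adapted from the new DataFrameJoinSuite test above. The SparkSession setup, the local master, and the cross-join config value are illustrative assumptions only.

import org.apache.spark.sql.SparkSession

object SelfJoinColumnReferenceExample {
  def main(args: Array[String]): Unit = {
    // Local session for illustration; cross joins are enabled to mirror the test above.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("self-join column reference example")
      .config("spark.sql.crossJoin.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    val df1 = spark.range(3)        // id: 0, 1, 2
    val df2 = df1.filter($"id" > 0) // id: 1, 2 -- derived from df1, so it reuses df1's expression IDs

    // With spark.sql.analyzer.resolveDatasetColumnReference disabled, `df1("id")` and `df2("id")`
    // resolve to the same attribute, the predicate compares a column with itself, and the join
    // returns no rows. With the rule enabled (the default in this patch), the hidden dataset-id
    // aliases let the rule map each reference to its own side, and the join returns (2, 1).
    df1.join(df2, df1("id") > df2("id")).show()

    spark.stop()
  }
}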