diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index f94fba13da77..efd88d0eb188 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -9,9 +9,9 @@ license: | The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -120,7 +120,7 @@ license: | - The `weekofyear`, `weekday`, `dayofweek`, `date_trunc`, `from_utc_timestamp`, `to_utc_timestamp`, and `unix_timestamp` functions use java.time API for calculation week number of year, day number of week as well for conversion from/to TimestampType values in UTC time zone. - the JDBC options `lowerBound` and `upperBound` are converted to TimestampType/DateType values in the same way as casting strings to TimestampType/DateType values. The conversion is based on Proleptic Gregorian calendar, and time zone defined by the SQL config `spark.sql.session.timeZone`. In Spark version 2.4 and earlier, the conversion is based on the hybrid calendar (Julian + Gregorian) and on default system time zone. - + - Formatting of `TIMESTAMP` and `DATE` literals. - In Spark version 2.4 and earlier, invalid time zone ids are silently ignored and replaced by GMT time zone, for example, in the from_utc_timestamp function. Since Spark 3.0, such time zone ids are rejected, and Spark throws `java.time.DateTimeException`. @@ -143,7 +143,7 @@ license: | - Since Spark 3.0, when Avro files are written with user provided non-nullable schema, even the catalyst schema is nullable, Spark is still able to write the files. However, Spark will throw runtime NPE if any of the records contains null. - - Since Spark 3.0, we use a new protocol for fetching shuffle blocks, for external shuffle service users, we need to upgrade the server correspondingly. Otherwise, we'll get the error message `UnsupportedOperationException: Unexpected message: FetchShuffleBlocks`. If it is hard to upgrade the shuffle service right now, you can still use the old protocol by setting `spark.shuffle.useOldFetchProtocol` to `true`. + - Since Spark 3.0, we use a new protocol for fetching shuffle blocks, for external shuffle service users, we need to upgrade the server correspondingly. Otherwise, we'll get the error message `UnsupportedOperationException: Unexpected message: FetchShuffleBlocks`. If it is hard to upgrade the shuffle service right now, you can still use the old protocol by setting `spark.shuffle.useOldFetchProtocol` to `true`. - Since Spark 3.0, a higher-order function `exists` follows the three-valued boolean logic, i.e., if the `predicate` returns any `null`s and no `true` is obtained, then `exists` will return `null` instead of `false`. For example, `exists(array(1, null, 3), x -> x % 2 == 0)` will be `null`. The previous behaviour can be restored by setting `spark.sql.legacy.arrayExistsFollowsThreeValuedLogic` to `false`. @@ -157,12 +157,14 @@ license: | - The result of `java.lang.Math`'s `log`, `log1p`, `exp`, `expm1`, and `pow` may vary across platforms. In Spark 3.0, the result of the equivalent SQL functions (including related SQL functions like `LOG10`) return values consistent with `java.lang.StrictMath`. 
In virtually all cases this makes no difference in the return value, and the difference is very small, but may not exactly match `java.lang.Math` on x86 platforms in cases like, for example, `log(3.0)`, whose value varies between `Math.log()` and `StrictMath.log()`. + - Since Spark 3.0, a Dataset query fails if it contains an ambiguous column reference caused by a self-join. A typical example: `val df1 = ...; val df2 = df1.filter(...);`, then `df1.join(df2, df1("a") > df2("a"))` returns an empty result, which is quite confusing. This is because Spark cannot resolve Dataset column references that point to tables being self-joined, and `df1("a")` is exactly the same as `df2("a")` in Spark. To restore the behavior before Spark 3.0, you can set `spark.sql.analyzer.failAmbiguousSelfJoin` to `false`. + ## Upgrading from Spark SQL 2.4 to 2.4.1 - The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was inconsistently interpreted as both seconds and milliseconds in Spark 2.4.0 in different parts of the code. Unitless values are now consistently interpreted as milliseconds. Applications that set values like "30" - need to specify a value with units like "30s" now, to avoid being interpreted as milliseconds; otherwise, + need to specify a value with units like "30s" now, to avoid being interpreted as milliseconds; otherwise, the extremely short interval that results will likely cause applications to fail. - When turning a Dataset to another Dataset, Spark will up cast the fields in the original Dataset to the type of corresponding fields in the target DataSet. In version 2.4 and earlier, this up cast is not very strict, e.g. `Seq("str").toDS.as[Int]` fails, but `Seq("str").toDS.as[Boolean]` works and throw NPE during execution. In Spark 3.0, the up cast is stricter and turning String into something else is not allowed, i.e. `Seq("str").toDS.as[Boolean]` will fail during analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5bf4dc1f045a..4970727b9646 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -323,7 +323,8 @@ class Analyzer( gid: Expression): Expression = { expr transform { case e: GroupingID => - if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) { + if (e.groupByExprs.isEmpty || + e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) { Alias(gid, toPrettySQL(e))() } else { throw new AnalysisException( @@ -1164,6 +1165,8 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) + // Intersect/Except will be rewritten to Join at the beginning of the optimizer. Here we need to + // deduplicate the right side plan, so that we won't produce an invalid self-join later.
case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) case e @ Except(left, right, _) if !e.duplicateResolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2fede591fc80..ae65f29c3226 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -794,6 +794,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val FAIL_AMBIGUOUS_SELF_JOIN = + buildConf("spark.sql.analyzer.failAmbiguousSelfJoin") + .doc("When true, fail the Dataset query if it contains ambiguous self-join.") + .internal() + .booleanConf + .createWithDefault(true) + // Whether to retain group by columns or not in GroupedData.agg. val DATAFRAME_RETAIN_GROUP_COLUMNS = buildConf("spark.sql.retainGroupColumns") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 5a408b29f933..b0de3c85aaef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -48,6 +48,15 @@ private[sql] object Column { case expr => toPrettySQL(expr) } } + + private[sql] def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = { + val metadataWithoutId = new MetadataBuilder() + .withMetadata(a.metadata) + .remove(Dataset.DATASET_ID_KEY) + .remove(Dataset.COL_POS_KEY) + .build() + a.withMetadata(metadataWithoutId) + } } /** @@ -144,11 +153,15 @@ class Column(val expr: Expression) extends Logging { override def toString: String = toPrettySQL(expr) override def equals(that: Any): Boolean = that match { - case that: Column => that.expr.equals(this.expr) + case that: Column => that.normalizedExpr() == this.normalizedExpr() case _ => false } - override def hashCode: Int = this.expr.hashCode() + override def hashCode: Int = this.normalizedExpr().hashCode() + + private def normalizedExpr(): Expression = expr transform { + case a: AttributeReference => Column.stripColumnReferenceMetadata(a) + } /** Creates a column based on the given expression. 
*/ private def withExpr(newExpr: Expression): Column = new Column(newExpr) @@ -1008,7 +1021,7 @@ class Column(val expr: Expression) extends Logging { * @since 2.0.0 */ def name(alias: String): Column = withExpr { - expr match { + normalizedExpr() match { case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) case other => Alias(other, alias)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ef03a09bba0b..87f4c8f5d949 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -30,7 +30,7 @@ import org.apache.spark.TaskContext import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ -import org.apache.spark.api.python.{PythonEvalType, PythonRDD, SerDeUtil} +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.api.r.RRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -45,6 +45,7 @@ import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters} import org.apache.spark.sql.execution.command._ @@ -52,6 +53,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -61,6 +63,11 @@ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils private[sql] object Dataset { + val curId = new java.util.concurrent.atomic.AtomicLong() + val DATASET_ID_KEY = "__dataset_id" + val COL_POS_KEY = "__col_position" + val DATASET_ID_TAG = TreeNodeTag[Long]("dataset_id") + def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) // Eagerly bind the encoder so we verify that the encoder matches the underlying @@ -182,6 +189,9 @@ class Dataset[T] private[sql]( @DeveloperApi @Unstable @transient val encoder: Encoder[T]) extends Serializable { + // A globally unique id of this Dataset. + private val id = Dataset.curId.getAndIncrement() + queryExecution.assertAnalyzed() // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure @@ -198,7 +208,7 @@ class Dataset[T] private[sql]( @transient private[sql] val logicalPlan: LogicalPlan = { // For various commands (like DDL) and queries with side effects, we force query execution // to happen right away to let these side effects take place eagerly. 
- queryExecution.analyzed match { + val plan = queryExecution.analyzed match { case c: Command => LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect())) case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => @@ -206,6 +216,10 @@ class Dataset[T] private[sql]( case _ => queryExecution.analyzed } + if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN)) { + plan.setTagValue(Dataset.DATASET_ID_TAG, id) + } + plan } /** @@ -1311,11 +1325,29 @@ class Dataset[T] private[sql]( if (sqlContext.conf.supportQuotedRegexColumnName) { colRegex(colName) } else { - val expr = resolve(colName) - Column(expr) + Column(addDataFrameIdToCol(resolve(colName))) } } + // Attach the dataset id and column position to the column reference, so that we can detect + // ambiguous self-joins correctly. See the rule `DetectAmbiguousSelfJoin`. + // This must be called before we return a `Column` that contains an `AttributeReference`. + // Note that the metadata added here is only available in the analyzer, as the analyzer rule + // `DetectAmbiguousSelfJoin` will remove it. + private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = { + val newExpr = expr transform { + case a: AttributeReference + if sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN) => + val metadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(Dataset.DATASET_ID_KEY, id) + .putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals)) + .build() + a.withMetadata(metadata) + } + newExpr.asInstanceOf[NamedExpression] + } + /** * Selects column based on the column name specified as a regex and returns it as [[Column]]. * @group untypedrel @@ -1329,7 +1361,7 @@ class Dataset[T] private[sql]( case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) => Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)) case _ => - Column(resolve(colName)) + Column(addDataFrameIdToCol(resolve(colName))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala new file mode 100644 index 000000000000..5c3c735f0346 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.analysis + +import scala.collection.mutable + +import org.apache.spark.sql.{AnalysisException, Column, Dataset} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Equality, Expression, ExprId} +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + +/** + * Detects ambiguous self-joins, so that we can fail the query instead of returning confusing + * results. + * + * A Dataset column reference is simply an [[AttributeReference]] that is returned by `Dataset#col`. + * Most of the time we don't need to do anything special, as [[AttributeReference]] can point to + * the column precisely. However, in case of self-join, the analyzer generates + * [[AttributeReference]] with new expr IDs for the right side plan of the join. If the Dataset + * column reference points to a column in the right side plan of a self-join, users will get + * unexpected results because the column reference can't match the newly generated + * [[AttributeReference]]. + * + * Note that this rule removes all the Dataset id related metadata from `AttributeReference`, so + * that it does not exist after the analyzer. + */ +class DetectAmbiguousSelfJoin(conf: SQLConf) extends Rule[LogicalPlan] { + + // A Dataset column reference is an `AttributeReference` with 2 special metadata entries. + private def isColumnReference(a: AttributeReference): Boolean = { + a.metadata.contains(Dataset.DATASET_ID_KEY) && a.metadata.contains(Dataset.COL_POS_KEY) + } + + private case class ColumnReference(datasetId: Long, colPos: Int, exprId: ExprId) + + private def toColumnReference(a: AttributeReference): ColumnReference = { + ColumnReference( + a.metadata.getLong(Dataset.DATASET_ID_KEY), + a.metadata.getLong(Dataset.COL_POS_KEY).toInt, + a.exprId) + } + + object LogicalPlanWithDatasetId { + def unapply(p: LogicalPlan): Option[(LogicalPlan, Long)] = { + p.getTagValue(Dataset.DATASET_ID_TAG).map(id => p -> id) + } + } + + object AttrWithCast { + def unapply(expr: Expression): Option[AttributeReference] = expr match { + case Cast(child, _, _) => unapply(child) + case a: AttributeReference => Some(a) + case _ => None + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN)) return plan + + // We always remove the special metadata from `AttributeReference` at the end of this rule, so + // Dataset column references only exist in the root node via Dataset transformations like + // `Dataset#select`. + val colRefAttrs = plan.expressions.flatMap(_.collect { + case a: AttributeReference if isColumnReference(a) => a + }) + + if (colRefAttrs.nonEmpty) { + val colRefs = colRefAttrs.map(toColumnReference).distinct + val ambiguousColRefs = mutable.HashSet.empty[ColumnReference] + val dsIdSet = colRefs.map(_.datasetId).toSet + + plan.foreach { + case LogicalPlanWithDatasetId(p, id) if dsIdSet.contains(id) => + colRefs.foreach { ref => + if (id == ref.datasetId) { + if (ref.colPos < 0 || ref.colPos >= p.output.length) { + throw new IllegalStateException("[BUG] Hit an invalid Dataset column reference: " + + s"$ref. Please open a JIRA ticket to report it.") + } else { + // When a self-join happens, the analyzer asks the right side plan to generate + // attributes with new exprIds. If a plan of a Dataset outputs an attribute which + // is referred to by a column reference, and this attribute has a different exprId + // than the attribute of the column reference, then the column reference is ambiguous, + // as it refers to a column that gets regenerated by the self-join. + val actualAttr = p.output(ref.colPos).asInstanceOf[AttributeReference] + if (actualAttr.exprId != ref.exprId) { + ambiguousColRefs += ref + } + } + } + } + + case _ => + } + + val ambiguousAttrs: Seq[AttributeReference] = plan match { + case Join( + LogicalPlanWithDatasetId(_, leftId), + LogicalPlanWithDatasetId(_, rightId), + _, condition, _) => + // If we are dealing with the root join node, we need to take care of SPARK-6231: + // 1. We can disambiguate `df("col") === df("col")` in the join condition. + // 2. There is no ambiguity in a direct self-join like + // `df.join(df, df("col") === 1)`, because it doesn't matter which side the + // column comes from. + def getAmbiguousAttrs(expr: Expression): Seq[AttributeReference] = expr match { + case Equality(AttrWithCast(a), AttrWithCast(b)) if a.sameRef(b) => + Nil + case Equality(AttrWithCast(a), b) if leftId == rightId && b.foldable => + Nil + case Equality(a, AttrWithCast(b)) if leftId == rightId && a.foldable => + Nil + case a: AttributeReference => + if (isColumnReference(a)) { + val colRef = toColumnReference(a) + if (ambiguousColRefs.contains(colRef)) Seq(a) else Nil + } else { + Nil + } + case _ => expr.children.flatMap(getAmbiguousAttrs) + } + condition.toSeq.flatMap(getAmbiguousAttrs) + + case _ => ambiguousColRefs.toSeq.map { ref => + colRefAttrs.find(attr => toColumnReference(attr) == ref).get + } + } + + if (ambiguousAttrs.nonEmpty) { + throw new AnalysisException(s"Column ${ambiguousAttrs.mkString(", ")} are ambiguous. " + + "It's probably because you joined several Datasets together, and some of these " + + "Datasets are the same. This column points to one of the Datasets but Spark is unable " + + "to figure out which one. Please alias the Datasets with different names via " + + "`Dataset.as` before joining them, and specify the column using qualified name, e.g. " + + """`df.as("a").join(df.as("b"), $"a.id" > $"b.id")`. You can also set """ + + s"${SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key} to false to disable this check.") + } + } + + plan.transformExpressions { + case a: AttributeReference if isColumnReference(a) => + // Remove the special metadata from this `AttributeReference`, as the detection is done.
+ Column.stripColumnReferenceMetadata(a) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index b05a5dfea3ff..16a63793e931 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck} import org.apache.spark.sql.streaming.StreamingQueryManager @@ -174,7 +175,8 @@ abstract class BaseSessionStateBuilder( customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = - PreprocessTableCreation(session) +: + new DetectAmbiguousSelfJoin(conf) +: + PreprocessTableCreation(session) +: PreprocessTableInsertion(conf) +: DataSourceAnalysis(conf) +: customPostHocResolutionRules diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e49ef012f5eb..c56c93f70857 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -167,6 +167,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, null, 1, 1, 3) :: Nil ) + // use column reference in `grouping_id` instead of column name + checkAnswer( + courseSales.cube("course", "year") + .agg(grouping_id(courseSales("course"), courseSales("year"))), + Row("Java", 2012, 0) :: + Row("Java", 2013, 0) :: + Row("Java", null, 1) :: + Row("dotNET", 2012, 0) :: + Row("dotNET", 2013, 0) :: + Row("dotNET", null, 1) :: + Row(null, 2012, 2) :: + Row(null, 2013, 2) :: + Row(null, null, 3) :: Nil + ) + intercept[AnalysisException] { courseSales.groupBy().agg(grouping("course")).explain() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index ba120dca712d..dc7928fde779 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -98,25 +98,6 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(3, "3", 4) :: Nil) } - test("join - join using self join") { - val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") - - // self join - checkAnswer( - df.join(df, "int"), - Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil) - } - - test("join - self join") { - val df1 = testData.select(testData("key")).as('df1) - val df2 = testData.select(testData("key")).as('df2) - - checkAnswer( - df1.join(df2, $"df1.key" === $"df2.key"), - sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") - .collect().toSeq) - } - test("join - cross join") { val df1 = Seq((1, "1"), (3, "3")).toDF("int", "str") val df2 = Seq((2, "2"), (4, "4")).toDF("int", "str") @@ -132,38 +113,6 
@@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(4, "4", 1, "1") :: Row(4, "4", 3, "3") :: Nil) } - test("join - using aliases after self join") { - val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") - checkAnswer( - df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), - Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) - - checkAnswer( - df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), - Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) - } - - test("[SPARK-6231] join - self join auto resolve ambiguity") { - val df = Seq((1, "1"), (2, "2")).toDF("key", "value") - checkAnswer( - df.join(df, df("key") === df("key")), - Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) - - checkAnswer( - df.join(df.filter($"value" === "2"), df("key") === df("key")), - Row(2, "2", 2, "2") :: Nil) - - checkAnswer( - df.join(df, df("key") === df("key") && df("value") === 1), - Row(1, "1", 1, "1") :: Nil) - - val left = df.groupBy("key").agg(count("*")) - val right = df.groupBy("key").agg(sum("key")) - checkAnswer( - left.join(right, left("key") === right("key")), - Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) - } - test("broadcast join hint using broadcast function") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala new file mode 100644 index 000000000000..92f1e4306c5b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions.{count, sum} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("join - join using self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + + // self join + checkAnswer( + df.join(df, "int"), + Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil) + } + + test("join - self join") { + val df1 = testData.select(testData("key")).as('df1) + val df2 = testData.select(testData("key")).as('df2) + + checkAnswer( + df1.join(df2, $"df1.key" === $"df2.key"), + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + .collect().toSeq) + } + + test("join - self join auto resolve ambiguity with case insensitivity") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.join(df, df("key") === df("Key")), + Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df.filter($"value" === "2"), df("key") === df("Key")), + Row(2, "2", 2, "2") :: Nil) + } + + test("join - using aliases after self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + test("[SPARK-6231] join - self join auto resolve ambiguity") { + val df = Seq((1, "1"), (2, "2")).toDF("key", "value") + checkAnswer( + df.join(df, df("key") === df("key")), + Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df.filter($"value" === "2"), df("key") === df("key")), + Row(2, "2", 2, "2") :: Nil) + + checkAnswer( + df.join(df, df("key") === df("key") && df("value") === 1), + Row(1, "1", 1, "1") :: Nil) + + val left = df.groupBy("key").agg(count("*")) + val right = df.groupBy("key").agg(sum("key")) + checkAnswer( + left.join(right, left("key") === right("key")), + Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) + } + + private def assertAmbiguousSelfJoin(df: => DataFrame): Unit = { + val e = intercept[AnalysisException](df) + assert(e.message.contains("ambiguous")) + } + + test("SPARK-28344: fail ambiguous self join - column ref in join condition") { + val df1 = spark.range(3) + val df2 = df1.filter($"id" > 0) + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "false", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // `df1("id") > df2("id")` is always false. + checkAnswer(df1.join(df2, df1("id") > df2("id")), Nil) + + // Aliasing the DataFrames and using qualified column names fixes the ambiguous self-join. + val aliasedDf1 = df1.alias("left") + val aliasedDf2 = df2.as("right") + checkAnswer( + aliasedDf1.join(aliasedDf2, $"left.id" > $"right.id"), + Seq(Row(2, 1))) + } + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertAmbiguousSelfJoin(df1.join(df2, df1("id") > df2("id"))) + } + } + + test("SPARK-28344: fail ambiguous self join - Dataset.colRegex as column ref") { + val df1 = spark.range(3) + val df2 = df1.filter($"id" > 0) + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertAmbiguousSelfJoin(df1.join(df2, df1.colRegex("id") > df2.colRegex("id"))) + } + } + + test("SPARK-28344: fail ambiguous self join - Dataset.col with nested field") { + val df1 = spark.read.json(Seq("""{"a": {"b": 1, "c": 1}}""").toDS()) + val df2 = df1.filter($"a.b" > 0) + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertAmbiguousSelfJoin(df1.join(df2, df1("a.b") > df2("a.c"))) + } + } + + test("SPARK-28344: fail ambiguous self join - column ref in Project") { + val df1 = spark.range(3) + val df2 = df1.filter($"id" > 0) + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "false", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // `df2("id")` actually points to the column of `df1`. + checkAnswer(df1.join(df2).select(df2("id")), Seq(0, 0, 1, 1, 2, 2).map(Row(_))) + + // Aliasing the DataFrames and using qualified column names fixes the ambiguous self-join. + val aliasedDf1 = df1.alias("left") + val aliasedDf2 = df2.as("right") + checkAnswer( + aliasedDf1.join(aliasedDf2).select($"right.id"), + Seq(1, 1, 1, 2, 2, 2).map(Row(_))) + } + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertAmbiguousSelfJoin(df1.join(df2).select(df2("id"))) + } + } + + test("SPARK-28344: fail ambiguous self join - join three tables") { + val df1 = spark.range(3) + val df2 = df1.filter($"id" > 0) + val df3 = df1.filter($"id" <= 2) + val df4 = spark.range(1) + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "false", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // `df2("id") < df3("id")` is always false + checkAnswer(df1.join(df2).join(df3, df2("id") < df3("id")), Nil) + // `df2("id")` actually points to the column of `df1`. + checkAnswer( + df1.join(df4).join(df2).select(df2("id")), + Seq(0, 0, 1, 1, 2, 2).map(Row(_))) + // `df4("id")` is not ambiguous. + checkAnswer( + df1.join(df4).join(df2).select(df4("id")), + Seq(0, 0, 0, 0, 0, 0).map(Row(_))) + + // Aliasing the DataFrames and using qualified column names fixes the ambiguous self-join.
+ val aliasedDf1 = df1.alias("x") + val aliasedDf2 = df2.as("y") + val aliasedDf3 = df3.as("z") + checkAnswer( + aliasedDf1.join(aliasedDf2).join(aliasedDf3, $"y.id" < $"z.id"), + Seq(Row(0, 1, 2), Row(1, 1, 2), Row(2, 1, 2))) + checkAnswer( + aliasedDf1.join(df4).join(aliasedDf2).select($"y.id"), + Seq(1, 1, 1, 2, 2, 2).map(Row(_))) + } + + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN.key -> "true", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertAmbiguousSelfJoin(df1.join(df2).join(df3, df2("id") < df3("id"))) + assertAmbiguousSelfJoin(df1.join(df4).join(df2).select(df2("id"))) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2fa108825982..cd609002410a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner +import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck} import org.apache.spark.sql.hive.client.HiveClient @@ -78,7 +79,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = - new DetermineTableStats(session) +: + new DetectAmbiguousSelfJoin(conf) +: + new DetermineTableStats(session) +: RelationConversions(conf, catalog) +: PreprocessTableCreation(session) +: PreprocessTableInsertion(conf) +: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala deleted file mode 100644 index cdc259d75b13..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHiveSingleton - -class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { - import spark.implicits._ - - // We should move this into SQL package if we make case sensitivity configurable in SQL. 
- test("join - self join auto resolve ambiguity with case insensitivity") { - val df = Seq((1, "1"), (2, "2")).toDF("key", "value") - checkAnswer( - df.join(df, df("key") === df("Key")), - Row(1, "1", 1, "1") :: Row(2, "2", 2, "2") :: Nil) - - checkAnswer( - df.join(df.filter($"value" === "2"), df("key") === df("Key")), - Row(2, "2", 2, "2") :: Nil) - } - -}
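
Reviewer note: the sketch below is not part of the patch; it illustrates how the new check introduced by this PR surfaces to users, based on the migration-guide entry and the tests above. It assumes a spark-shell style session where `spark` is the active `SparkSession`; the error text quoted in the comments comes from `DetectAmbiguousSelfJoin`, and the cross-join config is set only to mirror the test setup.

```scala
// Minimal usage sketch for the SPARK-28344 behavior (assumes spark-shell, `spark` in scope).
import org.apache.spark.sql.AnalysisException
import spark.implicits._

val df1 = spark.range(3)
val df2 = df1.filter($"id" > 0)

// With the new default (spark.sql.analyzer.failAmbiguousSelfJoin=true), an ambiguous column
// reference in a self-join now fails at analysis time instead of silently returning an
// empty result.
try {
  df1.join(df2, df1("id") > df2("id")).show()
} catch {
  case e: AnalysisException => println(e.getMessage) // message mentions "ambiguous"
}

// Recommended fix: alias each side and use qualified column names.
df1.as("l").join(df2.as("r"), $"l.id" > $"r.id").show() // single row: l.id = 2, r.id = 1

// Escape hatch: restore the pre-3.0 behavior for this session (plus cross joins, as the
// degenerate condition no longer references the right side, mirroring the tests above).
spark.conf.set("spark.sql.analyzer.failAmbiguousSelfJoin", "false")
spark.conf.set("spark.sql.crossJoin.enabled", "true")
df1.join(df2, df1("id") > df2("id")).count() // 0, the old confusing result
```

The detection piggybacks on per-Dataset ids and column positions attached as column metadata by `Dataset#col`/`colRegex`, which is why only column references obtained through those methods participate in the check, and why the metadata is stripped again once the analyzer rule has run.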