diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 6e27192ead32..524e216e789b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier} import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.MetadataBuilder trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { @@ -518,26 +519,126 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { case _ => e } + private def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = { + val metadataWithoutId = new MetadataBuilder() + .withMetadata(a.metadata) + .remove(LogicalPlan.DATASET_ID_KEY) + .remove(LogicalPlan.COL_POS_KEY) + .build() + a.withMetadata(metadataWithoutId) + } + + private def resolveUsingDatasetId( + ua: UnresolvedAttribute, + left: LogicalPlan, + right: LogicalPlan, + datasetId: Long): Option[NamedExpression] = { + def findUnaryNodeMatchingTagId(lp: LogicalPlan): Option[(LogicalPlan, Int)] = { + var currentLp = lp + var depth = 0 + while (true) { + if (currentLp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists( + _.contains(datasetId))) { + return Option(currentLp, depth) + } else { + if (currentLp.children.size == 1) { + currentLp = currentLp.children.head + } else { + // leaf node or node is a binary node + return None + } + } + depth += 1 + } + None + } + + val leftDefOpt = findUnaryNodeMatchingTagId(left) + val rightDefOpt = findUnaryNodeMatchingTagId(right) + val resolveOnAttribs = (leftDefOpt, rightDefOpt) match { + + case (None, Some((lp, _))) => lp.output + + case (Some((lp, _)), None) => lp.output + + case (Some((lp1, depth1)), Some((lp2, depth2))) => if (depth1 == depth2) { + Seq.empty + } else if (depth1 < depth2) { + lp1.output + } else { + lp2.output + } + + case _ => Seq.empty + } + if (resolveOnAttribs.isEmpty) { + None + } else { + AttributeSeq.fromNormalOutput(resolveOnAttribs).resolve(Seq(ua.name), conf.resolver) + } + } + private def resolveDataFrameColumn( u: UnresolvedAttribute, q: Seq[LogicalPlan]): Option[NamedExpression] = { - val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG) - if (planIdOpt.isEmpty) return None - val planId = planIdOpt.get - logDebug(s"Extract plan_id $planId from $u") - - val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty - - val (resolved, matched) = resolveDataFrameColumnByPlanId( - u, planId, isMetadataAccess, q, 0) - if (!matched) { - // Can not find the target plan node with plan id, e.g. 
- // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) - // df2 = spark.createDataFrame([Row(a = 1, b = 2)]]) - // df1.select(df2.a) <- illegal reference df2.a - throw QueryCompilationErrors.cannotResolveDataFrameColumn(u) + val origAttrOpt = u.getTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG) + val resolvedOptWithDatasetId = if (origAttrOpt.isDefined) { + val md = origAttrOpt.get.metadata + if (md.contains(LogicalPlan.DATASET_ID_KEY)) { + val did = md.getLong(LogicalPlan.DATASET_ID_KEY) + val resolved = if (q.size == 1) { + val binaryNodeOpt = q.head.collectFirst { + case bn: BinaryNode => bn + } + binaryNodeOpt.flatMap(bn => resolveUsingDatasetId(u, bn.left, bn.right, did)) + } else if (q.size == 2) { + resolveUsingDatasetId(u, q(0), q(1), did) + } else { + None + } + if (resolved.isEmpty) { + if (conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { + origAttrOpt + } else { + origAttrOpt.map(stripColumnReferenceMetadata) + } + } else { + resolved + } + } else { + origAttrOpt + } + } else { + None + } + val resolvedOpt = if (resolvedOptWithDatasetId.isDefined) { + resolvedOptWithDatasetId + } + else { + val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG) + if (planIdOpt.isEmpty) { + None + } else { + val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG) + if (planIdOpt.isEmpty) return None + val planId = planIdOpt.get + logDebug(s"Extract plan_id $planId from $u") + + val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty + + val (resolved, matched) = resolveDataFrameColumnByPlanId( + u, planId, isMetadataAccess, q, 0) + if (!matched) { + // Can not find the target plan node with plan id, e.g. + // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) + // df2 = spark.createDataFrame([Row(a = 1, b = 2)]]) + // df1.select(df2.a) <- illegal reference df2.a + throw QueryCompilationErrors.cannotResolveDataFrameColumn(u) + } + resolved.map(_._1) + } } - resolved.map(_._1) + resolvedOpt } private def resolveDataFrameColumnByPlanId( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b989233da674..7491d70ebfb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical +import scala.collection.mutable + import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException @@ -30,7 +32,6 @@ import org.apache.spark.sql.catalyst.util.MetadataColumnHelper import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.StructType - abstract class LogicalPlan extends QueryPlan[LogicalPlan] with AnalysisHelper @@ -199,6 +200,10 @@ object LogicalPlan { // to the old code path. 
private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id") private[spark] val IS_METADATA_COL = TreeNodeTag[Unit]("is_metadata_col") + private[spark] val DATASET_ID_TAG = TreeNodeTag[mutable.HashSet[Long]]("dataset_id") + private[spark] val UNRESOLVED_ATTRIBUTE_MD_TAG = TreeNodeTag[AttributeReference]("orig-attr") + private[spark] val DATASET_ID_KEY = "__dataset_id" + private[spark] val COL_POS_KEY = "__col_position" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c29fd968fc19..b1f58baae13b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -38,6 +38,7 @@ import org.apache.spark.api.r.RRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.resource.ResourceProfile +import org.apache.spark.sql.Dataset.{DATASET_ID_KEY, DATASET_ID_TAG} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation @@ -47,7 +48,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId @@ -71,9 +72,9 @@ import org.apache.spark.util.Utils private[sql] object Dataset { val curId = new java.util.concurrent.atomic.AtomicLong() - val DATASET_ID_KEY = "__dataset_id" - val COL_POS_KEY = "__col_position" - val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id") + val DATASET_ID_KEY = LogicalPlan.DATASET_ID_KEY + val COL_POS_KEY = LogicalPlan.COL_POS_KEY + val DATASET_ID_TAG = LogicalPlan.DATASET_ID_TAG def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) @@ -222,11 +223,9 @@ class Dataset[T] private[sql]( @transient private[sql] val logicalPlan: LogicalPlan = { val plan = queryExecution.commandExecuted - if (sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { - val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long]) - dsIds.add(id) - plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds) - } + val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long]) + dsIds.add(id) + plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds) plan } @@ -1146,9 +1145,8 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(logicalPlan, right.logicalPlan, - JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE)) - .queryExecution.analyzed.asInstanceOf[Join] + tryAmbiguityResolution(right, joinExprs, joinType) + ).queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. 
if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
@@ -1169,6 +1167,36 @@ class Dataset[T] private[sql](
JoinWith.resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, plan)
}
+ private def tryAmbiguityResolution(
+ right: Dataset[_],
+ joinExprs: Option[Column],
+ joinType: String): Join = {
+ val planPart1 = withPlan(
+ Join(logicalPlan, right.logicalPlan,
+ JoinType(joinType), None, JoinHint.NONE)).queryExecution.analyzed.asInstanceOf[Join]
+
+ val leftTagIdMap = planPart1.left.getTagValue(DATASET_ID_TAG)
+ val rightTagIdMap = planPart1.right.getTagValue(DATASET_ID_TAG)
+
+ val joinExprsRectified = joinExprs.map(_.expr transformUp {
+ case attr: AttributeReference if attr.metadata.contains(DATASET_ID_KEY) =>
+ // Keep the attribute as-is only if it is part of the join output and neither leg
+ // resolves it incorrectly; otherwise fall back to an UnresolvedAttribute that carries
+ // the original attribute so the analyzer can re-resolve it by dataset id.
+ val leftLegWrong = isIncorrectlyResolved(attr, planPart1.left.outputSet,
+ leftTagIdMap.getOrElse(HashSet.empty[Long]))
+ val rightLegWrong = isIncorrectlyResolved(attr, planPart1.right.outputSet,
+ rightTagIdMap.getOrElse(HashSet.empty[Long]))
+ if (!planPart1.outputSet.contains(attr) || leftLegWrong || rightLegWrong) {
+ val ua = UnresolvedAttribute(Seq(attr.name))
+ ua.copyTagsFrom(attr)
+ ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, attr)
+ ua
+ } else {
+ attr
+ }
+ })
+ Join(planPart1.left, planPart1.right, JoinType(joinType), joinExprsRectified, JoinHint.NONE)
+ }
+
/**
* Join with another `DataFrame`, using the given join expression. The following performs
* a full outer join between `df1` and `df2`.
@@ -1176,7 +1204,7 @@ class Dataset[T] private[sql](
* {{{
* // Scala:
* import org.apache.spark.sql.functions._
* df1.join(df2, $"df1Key" === $"df2Key", "outer")
*
* // Java:
* import static org.apache.spark.sql.functions.*;
@@ -1305,11 +1333,23 @@ class Dataset[T] private[sql](
case a: AttributeReference if logicalPlan.outputSet.contains(a) =>
val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.left.output(index)
+
+ case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
+ val ua = UnresolvedAttribute(Seq(a.name))
+ ua.copyTagsFrom(a)
+ ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, a)
+ ua
}
val rightAsOfExpr = rightAsOf.expr.transformUp {
case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.right.output(index)
+
+ case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
+ val ua = UnresolvedAttribute(Seq(a.name))
+ ua.copyTagsFrom(a)
+ ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, a)
+ ua
}
withPlan {
AsOfJoin(
@@ -1482,8 +1522,8 @@ class Dataset[T] private[sql](
// `DetectAmbiguousSelfJoin` will remove it.
private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = { val newExpr = expr transform { - case a: AttributeReference - if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) => + case a: AttributeReference => + // if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) => val metadata = new MetadataBuilder() .withMetadata(a.metadata) .putLong(Dataset.DATASET_ID_KEY, id) @@ -1573,7 +1613,17 @@ class Dataset[T] private[sql]( case other => other } - Project(untypedCols.map(_.named), logicalPlan) + val inputForProj = logicalPlan.outputSet + val namedExprs = untypedCols.map(ne => (ne.named transformUp { + case attr: AttributeReference if attr.metadata.contains(DATASET_ID_KEY) && + (!inputForProj.contains(attr) || + isIncorrectlyResolved(attr, inputForProj, HashSet(id))) => + val ua = UnresolvedAttribute(Seq(attr.name)) + ua.copyTagsFrom(attr) + ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, attr) + ua + }).asInstanceOf[NamedExpression]) + Project(namedExprs, logicalPlan) } /** @@ -4221,6 +4271,31 @@ class Dataset[T] private[sql]( queryExecution.analyzed.semanticHash() } + private def isIncorrectlyResolved( + attr: AttributeReference, + input: AttributeSet, + dataSetIdOfInput: HashSet[Long]): Boolean = { + val attrDatasetIdOpt = if (attr.metadata.contains(DATASET_ID_KEY)) { + Option(attr.metadata.getLong(DATASET_ID_KEY)) + } else { + None + } + attrDatasetIdOpt.forall(attrId => { + val matchingInputset = input.filter(_.canonicalized == attr.canonicalized) + if (matchingInputset.isEmpty) { + true + } else { + matchingInputset.forall(x => { + if (x.metadata.contains(DATASET_ID_KEY)) { + attrId != x.metadata.getLong(DATASET_ID_KEY) + } else { + !dataSetIdOfInput.contains(attrId) + } + }) + } + }) + } + //////////////////////////////////////////////////////////////////////////// // For Python API //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala index 280eb095dc75..4d1953f21890 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.jdk.CollectionConverters._ +import org.apache.spark.sql.catalyst.plans.logical.AsOfJoin import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession @@ -173,4 +174,23 @@ class DataFrameAsOfJoinSuite extends QueryTest ) ) } + + test("SPARK-47217: Dedup of relations can impact projected columns resolution") { + val (df1, df2) = prepareForAsOfJoin() + val join1 = df1.join(df2, df1.col("a") === df2.col("a")).select(df2.col("a"), df1.col("b"), + df2.col("b"), df1.col("a").as("aa")) + + // In stock spark this would throw ambiguous column exception, even though it is not ambiguous + val asOfjoin2 = join1.joinAsOf( + df1, df1.col("a"), join1.col("a"), usingColumns = Seq.empty, + joinType = "left", tolerance = null, allowExactMatches = false, direction = "nearest") + + asOfjoin2.queryExecution.assertAnalyzed() + + val testDf = asOfjoin2.select(df1.col("a")) + val analyzed = testDf.queryExecution.analyzed + val attributeRefToCheck = analyzed.output.head + assert(analyzed.children(0).asInstanceOf[AsOfJoin].right.outputSet. 
+ contains(attributeRefToCheck)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 7dc40549a17b..b7fb57b32595 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder} -import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, BinaryExpression, PythonUDF, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, Join, Project, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.{count, explode, sum, year} import org.apache.spark.sql.internal.SQLConf @@ -97,76 +97,100 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assert(e.message.contains("ambiguous")) } - test("SPARK-28344: fail ambiguous self join - column ref in join condition") { - val df1 = spark.range(3) - val df2 = df1.filter($"id" > 0) - - withSQLConf( - SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "false", - SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - // `df1("id") > df2("id")` is always false. - checkAnswer(df1.join(df2, df1("id") > df2("id")), Nil) - - // Alias the dataframe and use qualified column names can fix ambiguous self-join. 
- val aliasedDf1 = df1.alias("left")
- val aliasedDf2 = df2.as("right")
- checkAnswer(
- aliasedDf1.join(aliasedDf2, $"left.id" > $"right.id"),
- Seq(Row(2, 1)))
+ private def assertCorrectResolution(
+ df: => DataFrame,
+ leftResolution: Resolution.Resolution,
+ rightResolution: Resolution.Resolution): Unit = {
+ val join = df.queryExecution.analyzed.asInstanceOf[Join]
+ val binaryCondition = join.condition.get.asInstanceOf[BinaryExpression]
+ leftResolution match {
+ case Resolution.LeftConditionToLeftLeg =>
+ assert(join.left.outputSet.contains(binaryCondition.left.references.head))
+ case Resolution.LeftConditionToRightLeg =>
+ assert(join.right.outputSet.contains(binaryCondition.left.references.head))
}
- withSQLConf(
- SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
- SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
- assertAmbiguousSelfJoin(df1.join(df2, df1("id") > df2("id")))
+ rightResolution match {
+ case Resolution.RightConditionToLeftLeg =>
+ assert(join.left.outputSet.contains(binaryCondition.right.references.head))
+ case Resolution.RightConditionToRightLeg =>
+ assert(join.right.outputSet.contains(binaryCondition.right.references.head))
}
}
- test("SPARK-28344: fail ambiguous self join - Dataset.colRegex as column ref") {
- val df1 = spark.range(3)
- val df2 = df1.filter($"id" > 0)
-
- withSQLConf(
- SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
- SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
- assertAmbiguousSelfJoin(df1.join(df2, df1.colRegex("id") > df2.colRegex("id")))
- }
+ test("SPARK-28344: Not an ambiguous self join - column ref in join condition") {
+ Seq(true, false).foreach(fail => {
+ withSQLConf(
+ SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString,
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ val df1 = spark.range(3)
+ val df2 = df1.filter($"id" > 0)
+ // With each column reference bound to its own join leg, `df1("id") > df2("id")`
+ // keeps only the row (2, 1).
+ checkAnswer(df1.join(df2, df1("id") > df2("id")), Seq(Row(2, 1)))
+ assertCorrectResolution(df1.join(df2, df1("id") > df2("id")),
+ Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+
+ // Alias the dataframes and use qualified column names to eliminate all possibilities
+ // of ambiguity in a self-join.
+ val aliasedDf1 = df1.alias("left")
+ val aliasedDf2 = df2.as("right")
+ checkAnswer(
+ aliasedDf1.join(aliasedDf2, $"left.id" > $"right.id"),
+ Seq(Row(2, 1)))
+ assertCorrectResolution(aliasedDf1.join(aliasedDf2, $"left.id" > $"right.id"),
+ Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+ }
+ })
}
- test("SPARK-28344: fail ambiguous self join - Dataset.col with nested field") {
- val df1 = spark.read.json(Seq("""{"a": {"b": 1, "c": 1}}""").toDS())
- val df2 = df1.filter($"a.b" > 0)
-
- withSQLConf(
- SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
- SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
- assertAmbiguousSelfJoin(df1.join(df2, df1("a.b") > df2("a.c")))
- }
+ test("SPARK-28344: Not an ambiguous self join - Dataset.colRegex as column ref") {
+ Seq(true, false).foreach(fail => {
+ withSQLConf(
+ SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString,
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ val df1 = spark.range(3)
+ val df2 = df1.filter($"id" > 0)
+ assertCorrectResolution(df1.join(df2, df1.colRegex("id") > df2.colRegex("id")),
+ Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+ }
+ })
}
- test("SPARK-28344: fail ambiguous self join - column ref in Project") {
- val df1 = spark.range(3)
- val df2 = df1.filter($"id" > 0)
-
- withSQLConf(
- SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "false",
- SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
- // `df2("id")` actually points to the column of `df1`.
- checkAnswer(df1.join(df2).select(df2("id")), Seq(0, 0, 1, 1, 2, 2).map(Row(_)))
-
- // Alias the dataframe and use qualified column names can fix ambiguous self-join.
- val aliasedDf1 = df1.alias("left")
- val aliasedDf2 = df2.as("right")
- checkAnswer(
- aliasedDf1.join(aliasedDf2).select($"right.id"),
- Seq(1, 1, 1, 2, 2, 2).map(Row(_)))
- }
+ test("SPARK-28344: Not an ambiguous self join - Dataset.col with nested field") {
+ Seq(true, false).foreach(fail => {
+ withSQLConf(
+ SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString,
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ val df1 = spark.read.json(Seq("""{"a": {"b": 1, "c": 1}}""").toDS())
+ val df2 = df1.filter($"a.b" > 0)
+ assertCorrectResolution(df1.join(df2, df1("a.b") > df2("a.c")),
+ Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+ }
+ })
+ }
- withSQLConf(
- SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
- SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
- assertAmbiguousSelfJoin(df1.join(df2).select(df2("id")))
- }
+ test("SPARK-28344: Not an ambiguous self join - column ref in Project") {
+ Seq(true, false).foreach(fail => {
+ withSQLConf(
+ SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString,
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ val df1 = spark.range(3)
+ val df2 = df1.filter($"id" > 0)
+ // `df2("id")` now resolves to the column of `df2`, the right leg of the join.
+ checkAnswer(df1.join(df2).select(df2("id")), Seq(1, 1, 1, 2, 2, 2).map(Row(_)))
+
+ // Aliasing the dataframes and using qualified column names likewise avoids any ambiguity.
+ val aliasedDf1 = df1.alias("left") + val aliasedDf2 = df2.as("right") + checkAnswer( + aliasedDf1.join(aliasedDf2).select($"right.id"), + Seq(1, 1, 1, 2, 2, 2).map(Row(_))) + + val proj1 = df1.join(df2).select(df2("id")).queryExecution.analyzed.asInstanceOf[Project] + val join1 = proj1.child.asInstanceOf[Join] + assert(proj1.projectList(0).references.subsetOf(join1.right.outputSet)) + } + }) } test("SPARK-28344: fail ambiguous self join - join three tables") { @@ -178,12 +202,13 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { withSQLConf( SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "false", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - // `df2("id") < df3("id")` is always false - checkAnswer(df1.join(df2).join(df3, df2("id") < df3("id")), Nil) + // Here df3("id") is unambiguous, df2("id") is ambiguous. default resolves to df1 + checkAnswer(df1.join(df2).join(df3, df2("id") < df3("id")), + Seq(Row(0, 1, 1), Row(0, 2, 1), Row(0, 1, 2), Row(0, 2, 2), Row(1, 1, 2), Row(1, 2, 2))) // `df2("id")` actually points to the column of `df1`. - checkAnswer( + checkAnswer( df1.join(df4).join(df2).select(df2("id")), - Seq(0, 0, 1, 1, 2, 2).map(Row(_))) + Seq(1, 2, 1, 2, 1, 2).map(Row(_))) // `df4("id")` is not ambiguous. checkAnswer( df1.join(df4).join(df2).select(df4("id")), @@ -205,24 +230,30 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { assertAmbiguousSelfJoin(df1.join(df2).join(df3, df2("id") < df3("id"))) - assertAmbiguousSelfJoin(df1.join(df4).join(df2).select(df2("id"))) + + val proj1 = df1.join(df4).join(df2).select(df2("id")).queryExecution.analyzed. + asInstanceOf[Project] + val join1 = proj1.child.asInstanceOf[Join] + assert(proj1.projectList(0).references.subsetOf(join1.right.outputSet)) } } test("SPARK-28344: don't fail if there is no ambiguous self join") { - withSQLConf( - SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true") { - val df = Seq(1, 1, 2, 2).toDF("a") - val w = Window.partitionBy(df("a")) - checkAnswer( - df.select(df("a").alias("x"), sum(df("a")).over(w)), - Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple)) - - val joined = df.join(spark.range(1)).select($"a") - checkAnswer( - joined.select(joined("a").alias("x"), sum(joined("a")).over(w)), - Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple)) - } + Seq(true, false).foreach(fail => { + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) { + val df = Seq(1, 1, 2, 2).toDF("a") + val w = Window.partitionBy(df("a")) + checkAnswer( + df.select(df("a").alias("x"), sum(df("a")).over(w)), + Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple)) + + val joined = df.join(spark.range(1)).select($"a") + checkAnswer( + joined.select(joined("a").alias("x"), sum(joined("a")).over(w)), + Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple)) + } + }) } test("SPARK-33071/SPARK-33536: Avoid changing dataset_id of LogicalPlan in join() " + @@ -237,8 +268,17 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { TestData(2, "personnel"), TestData(3, "develop")).toDS() val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*")) - assertAmbiguousSelfJoin(emp1.join(emp3, emp1.col("key") === emp3.col("key"), - "left_outer").select(emp1.col("*"), emp3.col("key").as("e2"))) + + assertCorrectResolution(emp1.join(emp3, emp1.col("key") === emp3.col("key")), + Resolution.LeftConditionToLeftLeg, 
Resolution.RightConditionToRightLeg) + + val proj1 = emp1.join(emp3, emp1.col("key") === emp3.col("key"), + "left_outer").select(emp1.col("*"), emp3.col("key").as("e2")). + queryExecution.analyzed.asInstanceOf[Project] + val join1 = proj1.child.asInstanceOf[Join] + assert(proj1.projectList(0).references.subsetOf(join1.left.outputSet)) + assert(proj1.projectList(1).references.subsetOf(join1.left.outputSet)) + assert(proj1.projectList(2).references.subsetOf(join1.right.outputSet)) } test("df.show() should also not change dataset_id of LogicalPlan") { @@ -293,29 +333,52 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assert(col1DsId !== col2DsId) } - test("SPARK-35454: fail ambiguous self join - toDF") { - val df1 = spark.range(3).toDF() - val df2 = df1.filter($"id" > 0).toDF() - - withSQLConf( - SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true", - SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - assertAmbiguousSelfJoin(df1.join(df2, df1.col("id") > df2.col("id"))) - } + test("SPARK-35454: Not an ambiguous self join - toDF") { + Seq(true, false).foreach(fail => { + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString, + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + val df1 = spark.range(3).toDF() + val df2 = df1.filter($"id" > 0).toDF() + assertCorrectResolution(df1.join(df2, df1.col("id") > df2.col("id")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + } + }) } test("SPARK-35454: fail ambiguous self join - join four tables") { val df1 = spark.range(3).select($"id".as("a"), $"id".as("b")) val df2 = df1.filter($"a" > 0).select("b") val df3 = df1.filter($"a" <= 2).select("b") - val df4 = df1.filter($"b" <= 2) + val df4 = df1.filter($"b" <= 2).as("temp") val df5 = spark.range(1) withSQLConf( SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "false", SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - // `df2("b") < df4("b")` is always false - checkAnswer(df1.join(df2).join(df3).join(df4, df2("b") < df4("b")), Nil) + + // df4("b") is unambiguous + checkAnswer(df1.join(df2).join(df3).join(df4, df2("b") < df4("b")), + Seq( + Row(0, 0, 1, 0, 1, 1), + Row(0, 0, 1, 1, 1, 1), + Row(0, 0, 1, 2, 1, 1), + Row(0, 0, 2, 0, 1, 1), + Row(0, 0, 2, 1, 1, 1), + Row(0, 0, 2, 2, 1, 1), + Row(0, 0, 1, 0, 2, 2), + Row(0, 0, 1, 1, 2, 2), + Row(0, 0, 1, 2, 2, 2), + Row(0, 0, 2, 0, 2, 2), + Row(0, 0, 2, 1, 2, 2), + Row(0, 0, 2, 2, 2, 2), + Row(1, 1, 1, 0, 2, 2), + Row(1, 1, 1, 1, 2, 2), + Row(1, 1, 1, 2, 2, 2), + Row(1, 1, 2, 0, 2, 2), + Row(1, 1, 2, 1, 2, 2), + Row(1, 1, 2, 2, 2, 2) + )) // `df2("b")` actually points to the column of `df1`. 
checkAnswer( df1.join(df2).join(df5).join(df4).select(df2("b")), @@ -351,121 +414,146 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { test("SPARK-36874: DeduplicateRelations should copy dataset_id tag " + "to avoid ambiguous self join") { // Test for Project - val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value") - val df2 = df1.filter($"value" === "A2") - assertAmbiguousSelfJoin(df1.join(df2, df1("key1") === df2("key2"))) - assertAmbiguousSelfJoin(df2.join(df1, df1("key1") === df2("key2"))) - - // Test for SerializeFromObject - val df3 = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF() - val df4 = df3.filter($"_1" <=> 0) - assertAmbiguousSelfJoin(df3.join(df4, df3("_1") === df4("_2"))) - assertAmbiguousSelfJoin(df4.join(df3, df3("_1") === df4("_2"))) - - // Test For Aggregate - val df5 = df1.groupBy($"key1").agg(count($"value") as "count") - val df6 = df5.filter($"key1" > 0) - assertAmbiguousSelfJoin(df5.join(df6, df5("key1") === df6("count"))) - assertAmbiguousSelfJoin(df6.join(df5, df5("key1") === df6("count"))) - - // Test for MapInPandas - val mapInPandasUDF = PythonUDF("mapInPandasUDF", null, - StructType(Seq(StructField("x", LongType), StructField("y", LongType))), - Seq.empty, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, - true) - val df7 = df1.mapInPandas(mapInPandasUDF) - val df8 = df7.filter($"x" > 0) - assertAmbiguousSelfJoin(df7.join(df8, df7("x") === df8("y"))) - assertAmbiguousSelfJoin(df8.join(df7, df7("x") === df8("y"))) - - // Test for FlatMapGroupsInPandas - val flatMapGroupsInPandasUDF = PythonUDF("flagMapGroupsInPandasUDF", null, - StructType(Seq(StructField("x", LongType), StructField("y", LongType))), - Seq.empty, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - true) - val df9 = df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF) - val df10 = df9.filter($"x" > 0) - assertAmbiguousSelfJoin(df9.join(df10, df9("x") === df10("y"))) - assertAmbiguousSelfJoin(df10.join(df9, df9("x") === df10("y"))) - - // Test for FlatMapCoGroupsInPandas - val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null, - StructType(Seq(StructField("x", LongType), StructField("y", LongType))), - Seq.empty, - PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, - true) - val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas( - df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF) - val df12 = df11.filter($"x" > 0) - assertAmbiguousSelfJoin(df11.join(df12, df11("x") === df12("y"))) - assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y"))) - - // Test for AttachDistributedSequence - val df13 = df1.withSequenceColumn("seq") - val df14 = df13.filter($"value" === "A2") - assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2"))) - assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2"))) - - // Test for Generate - // Ensure that the root of the plan is Generate - val df15 = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList")) - .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF() - val df16 = df15.filter($"a" > 0) - assertAmbiguousSelfJoin(df15.join(df16, df15("a") === df16("col"))) - assertAmbiguousSelfJoin(df16.join(df15, df15("a") === df16("col"))) - - // Test for Expand - // Ensure that the root of the plan is Expand - val df17 = - Expand( - Seq(Seq($"key1".expr, $"key2".expr)), - Seq( - AttributeReference("x", IntegerType)(), - AttributeReference("y", IntegerType)()), - df1.queryExecution.logical).toDF() - val df18 = df17.filter($"x" 
> 0) - assertAmbiguousSelfJoin(df17.join(df18, df17("x") === df18("y"))) - assertAmbiguousSelfJoin(df18.join(df17, df17("x") === df18("y"))) - - // Test for Window - val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b") - // Ensure that the root of the plan is Window - val df19 = WindowPlan( - Seq(Alias(dfWithTS("time").expr, "ts")()), - Seq(dfWithTS("a").expr), - Seq(SortOrder(dfWithTS("a").expr, Ascending)), - dfWithTS.queryExecution.logical).toDF() - val df20 = df19.filter($"a" > 0) - assertAmbiguousSelfJoin(df19.join(df20, df19("a") === df20("b"))) - assertAmbiguousSelfJoin(df20.join(df19, df19("a") === df20("b"))) - - // Test for ScriptTransformation - val ioSchema = - ScriptInputOutputSchema( - Seq(("TOK_TABLEROWFORMATFIELD", ","), - ("TOK_TABLEROWFORMATCOLLITEMS", "#"), - ("TOK_TABLEROWFORMATMAPKEYS", "@"), - ("TOK_TABLEROWFORMATNULL", "null"), - ("TOK_TABLEROWFORMATLINES", "\n")), - Seq(("TOK_TABLEROWFORMATFIELD", ","), - ("TOK_TABLEROWFORMATCOLLITEMS", "#"), - ("TOK_TABLEROWFORMATMAPKEYS", "@"), - ("TOK_TABLEROWFORMATNULL", "null"), - ("TOK_TABLEROWFORMATLINES", "\n")), None, None, - List.empty, List.empty, None, None, false) - // Ensure that the root of the plan is ScriptTransformation - val df21 = ScriptTransformation( - "cat", - Seq( - AttributeReference("x", IntegerType)(), - AttributeReference("y", IntegerType)()), - df1.queryExecution.logical, - ioSchema).toDF() - val df22 = df21.filter($"x" > 0) - assertAmbiguousSelfJoin(df21.join(df22, df21("x") === df22("y"))) - assertAmbiguousSelfJoin(df22.join(df21, df21("x") === df22("y"))) + Seq(true, false).foreach(fail => { + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) { + val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value") + val df2 = df1.filter($"value" === "A2") + assertCorrectResolution(df1.join(df2, df1("key1") === df2("key2")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df2.join(df1, df1("key1") === df2("key2")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + + + // Test for SerializeFromObject + val df3 = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF() + val df4 = df3.filter($"_1" <=> 0) + assertCorrectResolution(df3.join(df4, df3("_1") === df4("_2")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df4.join(df3, df3("_1") === df4("_2")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + + // Test For Aggregate + val df5 = df1.groupBy($"key1").agg(count($"value") as "count") + val df6 = df5.filter($"key1" > 0) + assertCorrectResolution(df5.join(df6, df5("key1") === df6("count")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df6.join(df5, df5("key1") === df6("count")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + + // Test for MapInPandas + val mapInPandasUDF = PythonUDF("mapInPandasUDF", null, + StructType(Seq(StructField("x", LongType), StructField("y", LongType))), + Seq.empty, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + true) + val df7 = df1.mapInPandas(mapInPandasUDF) + val df8 = df7.filter($"x" > 0) + assertCorrectResolution(df7.join(df8, df7("x") === df8("y")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df8.join(df7, df7("x") === df8("y")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + + // 
Test for FlatMapGroupsInPandas + val flatMapGroupsInPandasUDF = PythonUDF("flagMapGroupsInPandasUDF", null, + StructType(Seq(StructField("x", LongType), StructField("y", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true) + val df9 = df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF) + val df10 = df9.filter($"x" > 0) + assertCorrectResolution(df9.join(df10, df9("x") === df10("y")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df10.join(df9, df9("x") === df10("y")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + + // Test for FlatMapCoGroupsInPandas + val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null, + StructType(Seq(StructField("x", LongType), StructField("y", LongType))), + Seq.empty, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + true) + val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas( + df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF) + val df12 = df11.filter($"x" > 0) + assertCorrectResolution(df11.join(df12, df11("x") === df12("y")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df12.join(df11, df11("x") === df12("y")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + + // Test for AttachDistributedSequence + val df13 = df1.withSequenceColumn("seq") + val df14 = df13.filter($"value" === "A2") + assertCorrectResolution(df13.join(df14, df13("key1") === df14("key2")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df14.join(df13, df13("key1") === df14("key2")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + // Test for Generate + // Ensure that the root of the plan is Generate + val df15 = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList")) + .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF() + val df16 = df15.filter($"a" > 0) + assertCorrectResolution(df15.join(df16, df15("a") === df16("col")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df16.join(df15, df15("a") === df16("col")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + + // Test for Expand + // Ensure that the root of the plan is Expand + val df17 = + Expand( + Seq(Seq($"key1".expr, $"key2".expr)), + Seq( + AttributeReference("x", IntegerType)(), + AttributeReference("y", IntegerType)()), + df1.queryExecution.logical).toDF() + val df18 = df17.filter($"x" > 0) + assertCorrectResolution(df17.join(df18, df17("x") === df18("y")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df18.join(df17, df17("x") === df18("y")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + // Test for Window + val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b") + // Ensure that the root of the plan is Window + val df19 = WindowPlan( + Seq(Alias(dfWithTS("time").expr, "ts")()), + Seq(dfWithTS("a").expr), + Seq(SortOrder(dfWithTS("a").expr, Ascending)), + dfWithTS.queryExecution.logical).toDF() + val df20 = df19.filter($"a" > 0) + assertCorrectResolution(df19.join(df20, df19("a") === df20("b")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df20.join(df19, df19("a") === df20("b")), + Resolution.LeftConditionToRightLeg, 
Resolution.RightConditionToLeftLeg) + // Test for ScriptTransformation + val ioSchema = + ScriptInputOutputSchema( + Seq(("TOK_TABLEROWFORMATFIELD", ","), + ("TOK_TABLEROWFORMATCOLLITEMS", "#"), + ("TOK_TABLEROWFORMATMAPKEYS", "@"), + ("TOK_TABLEROWFORMATNULL", "null"), + ("TOK_TABLEROWFORMATLINES", "\n")), + Seq(("TOK_TABLEROWFORMATFIELD", ","), + ("TOK_TABLEROWFORMATCOLLITEMS", "#"), + ("TOK_TABLEROWFORMATMAPKEYS", "@"), + ("TOK_TABLEROWFORMATNULL", "null"), + ("TOK_TABLEROWFORMATLINES", "\n")), None, None, + List.empty, List.empty, None, None, false) + // Ensure that the root of the plan is ScriptTransformation + val df21 = ScriptTransformation( + "cat", + Seq( + AttributeReference("x", IntegerType)(), + AttributeReference("y", IntegerType)()), + df1.queryExecution.logical, + ioSchema).toDF() + val df22 = df21.filter($"x" > 0) + assertCorrectResolution(df21.join(df22, df21("x") === df22("y")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + assertCorrectResolution(df22.join(df21, df21("x") === df22("y")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + } + }) } test("SPARK-35937: GetDateFieldOperations should skip unresolved nodes") { @@ -498,4 +586,85 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1) } } + + test("SPARK-47217: deduplication of project causes ambiguity in resolution") { + Seq(true, false).foreach(fail => { + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) { + val df = Seq((1, 2)).toDF("a", "b") + val df2 = df.select(df("a").as("aa"), df("b").as("bb")) + val df3 = df2.join(df, df2("bb") === df("b")).select(df2("aa"), df("a")) + checkAnswer( + df3, + Row(1, 1) :: Nil) + } + }) + } + + test("SPARK-47217: deduplication in nested joins with join attribute aliased") { + Seq(true, false).foreach(fail => { + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) { + val df1 = Seq((1, 2)).toDF("a", "b") + val df2 = Seq((1, 2)).toDF("aa", "bb") + val df1Joindf2 = df1.join(df2, df1("a") === df2("aa")).select(df1("a").as("aaa"), + df2("aa"), df1("b")) + + assertCorrectResolution(df1Joindf2.join(df1, df1Joindf2("aaa") === df1("a")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + + assertCorrectResolution(df1.join(df1Joindf2, df1Joindf2("aaa") === df1("a")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + + val proj1 = df1Joindf2.join(df1, df1Joindf2("aaa") === df1("a")).select(df1Joindf2("aa"), + df1("a")).queryExecution.analyzed.asInstanceOf[Project] + val join1 = proj1.child.asInstanceOf[Join] + assert(proj1.projectList(0).references.subsetOf(join1.left.outputSet)) + assert(proj1.projectList(1).references.subsetOf(join1.right.outputSet)) + + val proj2 = df1.join(df1Joindf2, df1Joindf2("aaa") === df1("a")).select(df1Joindf2("aa"), + df1("a")).queryExecution.analyzed.asInstanceOf[Project] + val join2 = proj2.child.asInstanceOf[Join] + assert(proj2.projectList(0).references.subsetOf(join2.right.outputSet)) + assert(proj2.projectList(1).references.subsetOf(join2.left.outputSet)) + } + }) + } + + test("SPARK-47217: deduplication in nested joins without join attribute aliased") { + Seq(true, false).foreach(fail => { + withSQLConf( + SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> fail.toString) { + val df1 = Seq((1, 2)).toDF("a", "b") + val df2 = Seq((1, 2)).toDF("aa", "bb") + val df1Joindf2 = df1.join(df2, df1("a") === 
df2("aa")).select(df1("a"), df2("aa"), df1("b")) + + assertCorrectResolution(df1Joindf2.join(df1, df1Joindf2("a") === df1("a")), + Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg) + + assertCorrectResolution(df1.join(df1Joindf2, df1Joindf2("a") === df1("a")), + Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg) + + val proj1 = df1Joindf2.join(df1, df1Joindf2("a") === df1("a")).select(df1Joindf2("a"), + df1("a")).queryExecution.analyzed.asInstanceOf[Project] + val join1 = proj1.child.asInstanceOf[Join] + assert(proj1.projectList(0).references.subsetOf(join1.left.outputSet)) + assert(proj1.projectList(1).references.subsetOf(join1.right.outputSet)) + + val proj2 = df1.join(df1Joindf2, df1Joindf2("a") === df1("a")).select(df1Joindf2("a"), + df1("a")).queryExecution.analyzed.asInstanceOf[Project] + val join2 = proj2.child.asInstanceOf[Join] + assert(proj2.projectList(0).references.subsetOf(join2.right.outputSet)) + assert(proj2.projectList(1).references.subsetOf(join2.left.outputSet)) + } + }) + } } + +object Resolution extends Enumeration { + type Resolution = Value + + val LeftConditionToLeftLeg, LeftConditionToRightLeg, RightConditionToRightLeg, + RightConditionToLeftLeg = Value +} +