diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 8ea50e2ceb65..4aac2c6c7067 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -134,6 +134,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
       expr: Expression,
       resolveColumnByName: Seq[String] => Option[Expression],
       getAttrCandidates: () => Seq[Attribute],
+      resolveOnDatasetId: (Long, String) => Option[NamedExpression],
       throws: Boolean,
       includeLastResort: Boolean): Expression = {
     def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) {
@@ -156,6 +157,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
           }
           matched(ordinal)

+        case u @ UnresolvedAttributeWithTag(attr, id) =>
+          resolveOnDatasetId(id, attr.name).getOrElse(attr)
+
         case u @ UnresolvedAttribute(nameParts) =>
           val result = withPosition(u) {
             resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map {
@@ -452,6 +456,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
         plan.resolve(nameParts, conf.resolver)
       },
       getAttrCandidates = () => plan.output,
+      resolveOnDatasetId = (_, _) => None,
       throws = throws,
       includeLastResort = includeLastResort)
   }
@@ -477,6 +482,57 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
         assert(q.children.length == 1)
         q.children.head.output
       },
+
+      // Resolve the name against the output of the plan node that is tagged with the given
+      // dataset id. If both join legs contain a matching node, prefer the shallower one; if the
+      // depths tie or nothing matches, fall back to the single child's output.
+      resolveOnDatasetId = (datasetId: Long, name: String) => {
+        def findUnaryNodeMatchingTagId(lp: LogicalPlan): Option[(LogicalPlan, Int)] = {
+          var currentLp = lp
+          var depth = 0
+          while (true) {
+            if (currentLp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(datasetId))) {
+              return Some((currentLp, depth))
+            } else {
+              if (currentLp.children.size == 1) {
+                currentLp = currentLp.children.head
+              } else {
+                // leaf node or binary node: stop searching this leg
+                return None
+              }
+            }
+            depth += 1
+          }
+          None
+        }
+
+        val binaryNodeOpt = q.collectFirst {
+          case bn: BinaryNode => bn
+        }
+
+        val resolveOnAttribs = binaryNodeOpt match {
+          case Some(bn) =>
+            val leftDefOpt = findUnaryNodeMatchingTagId(bn.left)
+            val rightDefOpt = findUnaryNodeMatchingTagId(bn.right)
+            (leftDefOpt, rightDefOpt) match {
+              case (None, Some((lp, _))) => lp.output
+
+              case (Some((lp, _)), None) => lp.output
+
+              case (Some((lp1, depth1)), Some((lp2, depth2))) =>
+                if (depth1 == depth2) {
+                  q.children.head.output
+                } else if (depth1 < depth2) {
+                  lp1.output
+                } else {
+                  lp2.output
+                }
+
+              case _ => q.children.head.output
+            }
+
+          case _ => q.children.head.output
+        }
+        AttributeSeq.fromNormalOutput(resolveOnAttribs).resolve(Seq(name), conf.resolver)
+      },
       throws = true,
       includeLastResort = includeLastResort)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 7a3cc4bc8e83..397351e0c1fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -268,6 +268,47 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
   }
 }

+/**
+ * An unresolved attribute that remembers the id of the Dataset it was obtained from, so that the
+ * analyzer can re-resolve it against the plan node tagged with the same dataset id instead of
+ * treating it as an ambiguous self-join reference.
+ */
+case class UnresolvedAttributeWithTag(attribute: Attribute, datasetId: Long) extends Attribute
+    with Unevaluable {
+  def name: String = attribute.name
+
+  override def exprId: ExprId = throw new UnresolvedException("exprId")
+
+  override def dataType: DataType = throw new UnresolvedException("dataType")
+
+  override def nullable: Boolean = throw new UnresolvedException("nullable")
+
+  override def qualifier: Seq[String] = throw new UnresolvedException("qualifier")
+
+  override lazy val resolved = false
+
+  override def newInstance(): UnresolvedAttributeWithTag = this
+
+  override def withNullability(newNullability: Boolean): UnresolvedAttributeWithTag = this
+
+  override def withQualifier(newQualifier: Seq[String]): UnresolvedAttributeWithTag = this
+
+  override def withName(newName: String): UnresolvedAttributeWithTag = this
+
+  override def withMetadata(newMetadata: Metadata): Attribute = this
+
+  override def withExprId(newExprId: ExprId): UnresolvedAttributeWithTag = this
+
+  override def withDataType(newType: DataType): Attribute = this
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE)
+
+  override def toString: String = s"'$name"
+
+  override def sql: String = attribute.sql
+
+  /**
+   * Returns true if this matches the token. This requires the attribute to have only one part in
+   * its name, and that part must match the given token case-insensitively.
+   */
+  def equalsIgnoreCase(token: String): Boolean = token.equalsIgnoreCase(attribute.name)
+}
+
 object UnresolvedAttribute extends AttributeNameParser {
   /**
    * Creates an [[UnresolvedAttribute]], parsing segments separated by dots ('.').
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index e1121d1f9026..a9b130c981ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.catalyst.plans.logical

+import scala.collection.mutable
+
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
@@ -30,7 +32,6 @@ import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.{MapType, StructType}

-
 abstract class LogicalPlan
   extends QueryPlan[LogicalPlan]
   with AnalysisHelper
@@ -199,6 +200,7 @@ object LogicalPlan {
   // to the old code path.
   private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
   private[spark] val IS_METADATA_COL = TreeNodeTag[Unit]("is_metadata_col")
+  private[spark] val DATASET_ID_TAG = TreeNodeTag[mutable.HashSet[Long]]("dataset_id")
 }

 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 189be1d6a30d..0f15fcf51b8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -38,6 +38,7 @@ import org.apache.spark.api.r.RRDD
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.sql.Dataset.DATASET_ID_KEY
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
@@ -47,7 +48,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
 import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
@@ -73,7 +74,7 @@ private[sql] object Dataset {
   val curId = new java.util.concurrent.atomic.AtomicLong()
   val DATASET_ID_KEY = "__dataset_id"
   val COL_POS_KEY = "__col_position"
-  val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id")
+  val DATASET_ID_TAG = LogicalPlan.DATASET_ID_TAG

   def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
     val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -1150,10 +1151,10 @@ class Dataset[T] private[sql](

     // Trigger analysis so in the case of self-join, the analyzer will clone the plan.
     // After the cloning, left and right side will have distinct expression ids.
+
     val plan = withPlan(
-      Join(logicalPlan, right.logicalPlan,
-        JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE))
-      .queryExecution.analyzed.asInstanceOf[Join]
+      tryAmbiguityResolution(right, joinExprs, joinType)
+    ).queryExecution.analyzed.asInstanceOf[Join]

     // If auto self join alias is disabled, return the plan.
     if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
@@ -1174,6 +1175,32 @@ class Dataset[T] private[sql](
     JoinWith.resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, plan)
   }

+  // Analyzes the join with no condition first, then rewrites condition attributes whose dataset
+  // id metadata does not match the join leg they would resolve to (or that are missing from the
+  // analyzed output) into UnresolvedAttributeWithTag, so the analyzer can re-resolve them by
+  // dataset id.
+  private def tryAmbiguityResolution(
+      right: Dataset[_],
+      joinExprs: Option[Column],
+      joinType: String) = {
+    val planPart1 = withPlan(
+      Join(logicalPlan, right.logicalPlan,
+        JoinType(joinType), None, JoinHint.NONE))
+      .queryExecution.analyzed.asInstanceOf[Join]
+    val inputSet = planPart1.outputSet
+    val joinExprsRectified = joinExprs.map(_.expr transformUp {
+      case attr: AttributeReference if attr.metadata.contains(Dataset.DATASET_ID_KEY) =>
+        val attribTagId = attr.metadata.getLong(Dataset.DATASET_ID_KEY)
+        val leftTagIds = planPart1.left.getTagValue(LogicalPlan.DATASET_ID_TAG)
+        val rightTagIds = planPart1.right.getTagValue(LogicalPlan.DATASET_ID_TAG)
+        if (!inputSet.contains(attr) ||
+          (planPart1.left.outputSet.contains(attr) &&
+            !leftTagIds.exists(_.contains(attribTagId))) ||
+          (planPart1.right.outputSet.contains(attr) &&
+            !rightTagIds.exists(_.contains(attribTagId)))) {
+          UnresolvedAttributeWithTag(attr, attribTagId)
+        } else {
+          attr
+        }
+    })
+
+    Join(planPart1.left, planPart1.right, JoinType(joinType), joinExprsRectified, JoinHint.NONE)
+  }
+
   /**
    * Join with another `DataFrame`, using the given join expression. The following performs
    * a full outer join between `df1` and `df2`.
@@ -1308,12 +1335,20 @@ class Dataset[T] private[sql](
       case a: AttributeReference if logicalPlan.outputSet.contains(a) =>
         val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
         joined.left.output(index)
+
+      case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
+        UnresolvedAttributeWithTag(a, a.metadata.getLong(Dataset.DATASET_ID_KEY))
     }
+
     val rightAsOfExpr = rightAsOf.expr.transformUp {
       case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
         val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
         joined.right.output(index)
+
+      case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
+        UnresolvedAttributeWithTag(a, a.metadata.getLong(Dataset.DATASET_ID_KEY))
     }
+
     withPlan {
       AsOfJoin(
         joined.left, joined.right,
@@ -1576,7 +1611,52 @@ class Dataset[T] private[sql](
       case other => other
     }
-    Project(untypedCols.map(_.named), logicalPlan)
+    val namedExprs = untypedCols.map(_.named)
+    val inputSet = logicalPlan.outputSet
+    val rectifiedNamedExprs = namedExprs.map(ne => ne match {
+
+      case al: Alias if !al.references.subsetOf(inputSet) || al.references.exists(attr =>
+        attr.metadata.contains(DATASET_ID_KEY) && attr.metadata.getLong(DATASET_ID_KEY) !=
+          inputSet.find(_.canonicalized == attr.canonicalized).map(x =>
+            if (x.metadata.contains(DATASET_ID_KEY)) {
+              x.metadata.getLong(DATASET_ID_KEY)
+            } else {
+              Dataset.this.id
+            }).get)
+      =>
+        val unresolvedExpr = al.child.transformUp {
+          case attr: AttributeReference if attr.metadata.contains(Dataset.DATASET_ID_KEY) &&
+            (!inputSet.contains(attr) || attr.metadata.getLong(DATASET_ID_KEY) !=
+              inputSet.find(_.canonicalized == attr.canonicalized).map(x =>
+                if (x.metadata.contains(DATASET_ID_KEY)) {
+                  x.metadata.getLong(DATASET_ID_KEY)
+                } else {
+                  Dataset.this.id
+                }).get)
+          =>
+            UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY))
+        }
+        val newAl = al.copy(child = unresolvedExpr, name = al.name)(exprId = al.exprId,
+          qualifier = al.qualifier, explicitMetadata = al.explicitMetadata,
+          nonInheritableMetadataKeys = al.nonInheritableMetadataKeys)
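+        // Copy the tree node tags from the original alias onto the rewritten copy.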
+        newAl.copyTagsFrom(al)
+        newAl
+
+      case attr: Attribute if attr.metadata.contains(Dataset.DATASET_ID_KEY) &&
+        (!inputSet.contains(attr) || attr.metadata.getLong(DATASET_ID_KEY) !=
+          inputSet.find(_.canonicalized == attr.canonicalized).map(x =>
+            if (x.metadata.contains(DATASET_ID_KEY)) {
+              x.metadata.getLong(DATASET_ID_KEY)
+            } else {
+              Dataset.this.id
+            }).get)
+      =>
+        UnresolvedAttributeWithTag(attr, attr.metadata.getLong(Dataset.DATASET_ID_KEY))
+
+      case _ => ne
+
+    })
+    Project(rectifiedNamedExprs, logicalPlan)
   }

   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala
index 280eb095dc75..ec80c782b5b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql

 import scala.jdk.CollectionConverters._

+import org.apache.spark.sql.catalyst.plans.logical.AsOfJoin
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSparkSession
@@ -173,4 +174,23 @@ class DataFrameAsOfJoinSuite extends QueryTest
       )
     )
   }
+
+  test("SPARK-47217: dedup of relations can impact resolution of projected columns") {
+    val (df1, df2) = prepareForAsOfJoin()
+    val join1 = df1.join(df2, df1.col("a") === df2.col("a")).select(df2.col("a"), df1.col("b"),
+      df2.col("b"), df1.col("a").as("aa"))
+
+    // In stock Spark this would throw an ambiguous column exception, though it is not ambiguous.
+    val asOfJoin2 = join1.joinAsOf(
+      df1, df1.col("a"), join1.col("a"), usingColumns = Seq.empty,
+      joinType = "left", tolerance = null, allowExactMatches = false, direction = "nearest")
+
+    asOfJoin2.queryExecution.assertAnalyzed()
+
+    val testDf = asOfJoin2.select(df1.col("a"))
+    val analyzed = testDf.queryExecution.analyzed
+    val attributeRefToCheck = analyzed.output.head
+    assert(analyzed.children(0).asInstanceOf[AsOfJoin].right.outputSet.
+      contains(attributeRefToCheck))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 7dc40549a17b..4fccf9d2415c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -18,8 +18,8 @@
 package org.apache.spark.sql

 import org.apache.spark.api.python.PythonEvalType
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, BinaryExpression, PythonUDF, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, Join, Project, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.{count, explode, sum, year}
 import org.apache.spark.sql.internal.SQLConf
@@ -97,7 +97,28 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     assert(e.message.contains("ambiguous"))
   }

-  test("SPARK-28344: fail ambiguous self join - column ref in join condition") {
+  // Asserts that each side of the (binary) join condition resolved to the expected join leg.
+  private def assertCorrectResolution(
+      df: => DataFrame,
+      leftResolution: Resolution.Resolution,
+      rightResolution: Resolution.Resolution): Unit = {
+    val join = df.queryExecution.analyzed.asInstanceOf[Join]
+    val binaryCondition = join.condition.get.asInstanceOf[BinaryExpression]
+    leftResolution match {
+      case Resolution.LeftConditionToLeftLeg =>
+        assert(join.left.outputSet.contains(binaryCondition.left.references.head))
+      case Resolution.LeftConditionToRightLeg =>
+        assert(join.right.outputSet.contains(binaryCondition.left.references.head))
+    }
+
+    rightResolution match {
+      case Resolution.RightConditionToLeftLeg =>
+        assert(join.left.outputSet.contains(binaryCondition.right.references.head))
+      case Resolution.RightConditionToRightLeg =>
+        assert(join.right.outputSet.contains(binaryCondition.right.references.head))
+    }
+  }
+
+  test("SPARK-28344: not an ambiguous self join - column ref in join condition") {
     val df1 = spark.range(3)
     val df2 = df1.filter($"id" > 0)

@@ -118,29 +139,32 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     withSQLConf(
       SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
       SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
-      assertAmbiguousSelfJoin(df1.join(df2, df1("id") > df2("id")))
+      assertCorrectResolution(df1.join(df2, df1("id") > df2("id")),
+        Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
     }
   }

-  test("SPARK-28344: fail ambiguous self join - Dataset.colRegex as column ref") {
+  test("SPARK-28344: not an ambiguous self join - Dataset.colRegex as column ref") {
     val df1 = spark.range(3)
     val df2 = df1.filter($"id" > 0)

     withSQLConf(
       SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
       SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
-      assertAmbiguousSelfJoin(df1.join(df2, df1.colRegex("id") > df2.colRegex("id")))
+      assertCorrectResolution(df1.join(df2, df1.colRegex("id") > df2.colRegex("id")),
+        Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
     }
   }

-  test("SPARK-28344: fail ambiguous self join - Dataset.col with nested field") {
+  test("SPARK-28344: not an ambiguous self join - Dataset.col with nested field") {
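+    // df1("a.b") and df2("a.c") carry different __dataset_id metadata, so each side of the
+    // condition resolves against its own join leg.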
     val df1 = spark.read.json(Seq("""{"a": {"b": 1, "c": 1}}""").toDS())
     val df2 = df1.filter($"a.b" > 0)

     withSQLConf(
       SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
       SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
-      assertAmbiguousSelfJoin(df1.join(df2, df1("a.b") > df2("a.c")))
+      assertCorrectResolution(df1.join(df2, df1("a.b") > df2("a.c")),
+        Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
     }
   }

@@ -165,7 +189,9 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     withSQLConf(
       SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
       SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
-      assertAmbiguousSelfJoin(df1.join(df2).select(df2("id")))
+      val proj1 = df1.join(df2).select(df2("id")).queryExecution.analyzed.asInstanceOf[Project]
+      val join1 = proj1.child.asInstanceOf[Join]
+      assert(proj1.projectList(0).references.subsetOf(join1.right.outputSet))
     }
   }

@@ -205,7 +231,11 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
       SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
       assertAmbiguousSelfJoin(df1.join(df2).join(df3, df2("id") < df3("id")))
-      assertAmbiguousSelfJoin(df1.join(df4).join(df2).select(df2("id")))
+
+      val proj1 = df1.join(df4).join(df2).select(df2("id")).queryExecution.analyzed.
+        asInstanceOf[Project]
+      val join1 = proj1.child.asInstanceOf[Join]
+      assert(proj1.projectList(0).references.subsetOf(join1.right.outputSet))
     }
   }

@@ -237,8 +267,17 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       TestData(2, "personnel"),
       TestData(3, "develop")).toDS()
     val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))
-    assertAmbiguousSelfJoin(emp1.join(emp3, emp1.col("key") === emp3.col("key"),
-      "left_outer").select(emp1.col("*"), emp3.col("key").as("e2")))
+
+    assertCorrectResolution(emp1.join(emp3, emp1.col("key") === emp3.col("key")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+
+    val proj1 = emp1.join(emp3, emp1.col("key") === emp3.col("key"),
+      "left_outer").select(emp1.col("*"), emp3.col("key").as("e2")).
+      queryExecution.analyzed.asInstanceOf[Project]
+    val join1 = proj1.child.asInstanceOf[Join]
+    assert(proj1.projectList(0).references.subsetOf(join1.left.outputSet))
+    assert(proj1.projectList(1).references.subsetOf(join1.left.outputSet))
+    assert(proj1.projectList(2).references.subsetOf(join1.right.outputSet))
   }

   test("df.show() should also not change dataset_id of LogicalPlan") {
@@ -293,14 +332,15 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     assert(col1DsId !== col2DsId)
   }

-  test("SPARK-35454: fail ambiguous self join - toDF") {
+  test("SPARK-35454: not an ambiguous self join - toDF") {
     val df1 = spark.range(3).toDF()
     val df2 = df1.filter($"id" > 0).toDF()

     withSQLConf(
       SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
       SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
-      assertAmbiguousSelfJoin(df1.join(df2, df1.col("id") > df2.col("id")))
+      assertCorrectResolution(df1.join(df2, df1.col("id") > df2.col("id")),
+        Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
     }
   }

@@ -351,22 +391,30 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
   test("SPARK-36874: DeduplicateRelations should copy dataset_id tag " +
     "to avoid ambiguous self join") {
     // Test for Project
+
     val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value")
     val df2 = df1.filter($"value" === "A2")
-    assertAmbiguousSelfJoin(df1.join(df2, df1("key1") === df2("key2")))
-    assertAmbiguousSelfJoin(df2.join(df1, df1("key1") === df2("key2")))
+    assertCorrectResolution(df1.join(df2, df1("key1") === df2("key2")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df2.join(df1, df1("key1") === df2("key2")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
+
     // Test for SerializeFromObject
     val df3 = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF()
     val df4 = df3.filter($"_1" <=> 0)
-    assertAmbiguousSelfJoin(df3.join(df4, df3("_1") === df4("_2")))
-    assertAmbiguousSelfJoin(df4.join(df3, df3("_1") === df4("_2")))
+    assertCorrectResolution(df3.join(df4, df3("_1") === df4("_2")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df4.join(df3, df3("_1") === df4("_2")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)

     // Test For Aggregate
     val df5 = df1.groupBy($"key1").agg(count($"value") as "count")
     val df6 = df5.filter($"key1" > 0)
-    assertAmbiguousSelfJoin(df5.join(df6, df5("key1") === df6("count")))
-    assertAmbiguousSelfJoin(df6.join(df5, df5("key1") === df6("count")))
+    assertCorrectResolution(df5.join(df6, df5("key1") === df6("count")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df6.join(df5, df5("key1") === df6("count")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)

     // Test for MapInPandas
     val mapInPandasUDF = PythonUDF("mapInPandasUDF", null,
@@ -376,8 +424,10 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       true)
     val df7 = df1.mapInPandas(mapInPandasUDF)
     val df8 = df7.filter($"x" > 0)
-    assertAmbiguousSelfJoin(df7.join(df8, df7("x") === df8("y")))
-    assertAmbiguousSelfJoin(df8.join(df7, df7("x") === df8("y")))
+    assertCorrectResolution(df7.join(df8, df7("x") === df8("y")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df8.join(df7, df7("x") === df8("y")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)

     // Test for FlatMapGroupsInPandas
     val flatMapGroupsInPandasUDF = PythonUDF("flagMapGroupsInPandasUDF", null,
@@ -387,8 +437,10 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       true)
     val df9 = df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF)
     val df10 = df9.filter($"x" > 0)
-    assertAmbiguousSelfJoin(df9.join(df10, df9("x") === df10("y")))
-    assertAmbiguousSelfJoin(df10.join(df9, df9("x") === df10("y")))
+    assertCorrectResolution(df9.join(df10, df9("x") === df10("y")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df10.join(df9, df9("x") === df10("y")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)

     // Test for FlatMapCoGroupsInPandas
     val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null,
@@ -399,22 +451,27 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas(
       df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
     val df12 = df11.filter($"x" > 0)
-    assertAmbiguousSelfJoin(df11.join(df12, df11("x") === df12("y")))
-    assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
+    assertCorrectResolution(df11.join(df12, df11("x") === df12("y")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df12.join(df11, df11("x") === df12("y")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)

     // Test for AttachDistributedSequence
     val df13 = df1.withSequenceColumn("seq")
     val df14 = df13.filter($"value" === "A2")
-    assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
-    assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2")))
-
+    assertCorrectResolution(df13.join(df14, df13("key1") === df14("key2")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df14.join(df13, df13("key1") === df14("key2")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
     // Test for Generate
     // Ensure that the root of the plan is Generate
     val df15 = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList"))
       .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF()
     val df16 = df15.filter($"a" > 0)
-    assertAmbiguousSelfJoin(df15.join(df16, df15("a") === df16("col")))
-    assertAmbiguousSelfJoin(df16.join(df15, df15("a") === df16("col")))
+    assertCorrectResolution(df15.join(df16, df15("a") === df16("col")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df16.join(df15, df15("a") === df16("col")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)

     // Test for Expand
     // Ensure that the root of the plan is Expand
@@ -426,9 +483,10 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
         AttributeReference("y", IntegerType)()),
       df1.queryExecution.logical).toDF()
     val df18 = df17.filter($"x" > 0)
-    assertAmbiguousSelfJoin(df17.join(df18, df17("x") === df18("y")))
-    assertAmbiguousSelfJoin(df18.join(df17, df17("x") === df18("y")))
-
+    assertCorrectResolution(df17.join(df18, df17("x") === df18("y")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df18.join(df17, df17("x") === df18("y")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
     // Test for Window
     val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b")
     // Ensure that the root of the plan is Window
@@ -438,9 +496,10 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       Seq(SortOrder(dfWithTS("a").expr, Ascending)),
       dfWithTS.queryExecution.logical).toDF()
     val df20 = df19.filter($"a" > 0)
-    assertAmbiguousSelfJoin(df19.join(df20, df19("a") === df20("b")))
-    assertAmbiguousSelfJoin(df20.join(df19, df19("a") === df20("b")))
-
+    assertCorrectResolution(df19.join(df20, df19("a") === df20("b")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df20.join(df19, df19("a") === df20("b")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
     // Test for ScriptTransformation
     val ioSchema = ScriptInputOutputSchema(
@@ -464,8 +523,10 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       df1.queryExecution.logical, ioSchema).toDF()
     val df22 = df21.filter($"x" > 0)
-    assertAmbiguousSelfJoin(df21.join(df22, df21("x") === df22("y")))
-    assertAmbiguousSelfJoin(df22.join(df21, df21("x") === df22("y")))
+    assertCorrectResolution(df21.join(df22, df21("x") === df22("y")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+    assertCorrectResolution(df22.join(df21, df21("x") === df22("y")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
   }

   test("SPARK-35937: GetDateFieldOperations should skip unresolved nodes") {
@@ -498,4 +559,70 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1)
     }
   }
+
+  test("SPARK-47217: deduplication of project causes ambiguity in resolution") {
+    val df = Seq((1, 2)).toDF("a", "b")
+    val df2 = df.select(df("a").as("aa"), df("b").as("bb"))
+    val df3 = df2.join(df, df2("bb") === df("b")).select(df2("aa"), df("a"))
+    checkAnswer(
+      df3,
+      Row(1, 1) :: Nil)
+  }
+
+  test("SPARK-47217: deduplication in nested joins with join attribute aliased") {
+    val df1 = Seq((1, 2)).toDF("a", "b")
+    val df2 = Seq((1, 2)).toDF("aa", "bb")
+    val df1Joindf2 = df1.join(df2, df1("a") === df2("aa")).select(df1("a").as("aaa"),
+      df2("aa"), df1("b"))
+
+    assertCorrectResolution(df1Joindf2.join(df1, df1Joindf2("aaa") === df1("a")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+
+    assertCorrectResolution(df1.join(df1Joindf2, df1Joindf2("aaa") === df1("a")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
+
+    val proj1 = df1Joindf2.join(df1, df1Joindf2("aaa") === df1("a")).select(df1Joindf2("aa"),
+      df1("a")).queryExecution.analyzed.asInstanceOf[Project]
+    val join1 = proj1.child.asInstanceOf[Join]
+    assert(proj1.projectList(0).references.subsetOf(join1.left.outputSet))
+    assert(proj1.projectList(1).references.subsetOf(join1.right.outputSet))
+
+    val proj2 = df1.join(df1Joindf2, df1Joindf2("aaa") === df1("a")).select(df1Joindf2("aa"),
+      df1("a")).queryExecution.analyzed.asInstanceOf[Project]
+    val join2 = proj2.child.asInstanceOf[Join]
+    assert(proj2.projectList(0).references.subsetOf(join2.right.outputSet))
+    assert(proj2.projectList(1).references.subsetOf(join2.left.outputSet))
+  }
+
+  test("SPARK-47217: deduplication in nested joins without join attribute aliased") {
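+    // Same as the previous test, except the join key df1("a") is not aliased in df1Joindf2.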
+    val df1 = Seq((1, 2)).toDF("a", "b")
+    val df2 = Seq((1, 2)).toDF("aa", "bb")
+    val df1Joindf2 = df1.join(df2, df1("a") === df2("aa")).select(df1("a"), df2("aa"), df1("b"))
+
+    assertCorrectResolution(df1Joindf2.join(df1, df1Joindf2("a") === df1("a")),
+      Resolution.LeftConditionToLeftLeg, Resolution.RightConditionToRightLeg)
+
+    assertCorrectResolution(df1.join(df1Joindf2, df1Joindf2("a") === df1("a")),
+      Resolution.LeftConditionToRightLeg, Resolution.RightConditionToLeftLeg)
+
+    val proj1 = df1Joindf2.join(df1, df1Joindf2("a") === df1("a")).select(df1Joindf2("a"),
+      df1("a")).queryExecution.analyzed.asInstanceOf[Project]
+    val join1 = proj1.child.asInstanceOf[Join]
+    assert(proj1.projectList(0).references.subsetOf(join1.left.outputSet))
+    assert(proj1.projectList(1).references.subsetOf(join1.right.outputSet))
+
+    val proj2 = df1.join(df1Joindf2, df1Joindf2("a") === df1("a")).select(df1Joindf2("a"),
+      df1("a")).queryExecution.analyzed.asInstanceOf[Project]
+    val join2 = proj2.child.asInstanceOf[Join]
+    assert(proj2.projectList(0).references.subsetOf(join2.right.outputSet))
+    assert(proj2.projectList(1).references.subsetOf(join2.left.outputSet))
+  }
 }
+
+object Resolution extends Enumeration {
+  type Resolution = Value
+
+  val LeftConditionToLeftLeg, LeftConditionToRightLeg, RightConditionToRightLeg,
+    RightConditionToLeftLeg = Value
+}
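
Usage sketch (illustrative, adapted from the SPARK-47217 tests above; assumes a running
SparkSession `spark` with `spark.implicits._` imported). The reference df("b") reaches the join
condition carrying the __dataset_id of `df`, so after deduplication of relations it is re-resolved
against the correct join leg instead of failing as an ambiguous self-join column; the final query
is expected to return a single Row(1, 1), matching the checkAnswer in the test above.

  import spark.implicits._

  val df  = Seq((1, 2)).toDF("a", "b")
  val df2 = df.select(df("a").as("aa"), df("b").as("bb"))
  // `df` appears on both sides of the join after relation deduplication.
  val df3 = df2.join(df, df2("bb") === df("b")).select(df2("aa"), df("a"))
  df3.show()  // expected: one row, (1, 1)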