[SPARK-27547][SQL] Fix DataFrame self-join problems #24442
Changes from all commits: bb0402a, 5a8886b, 7abdf03, 13b23d6, 33fa81d, 5b2c0ac, 8d4cea6, e7bfcc8
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

@@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable}
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.execution.stat.StatFunctions
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
@@ -62,6 +63,11 @@ import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.Utils

 private[sql] object Dataset {
+  val curId = new java.util.concurrent.atomic.AtomicLong()
+
+  val ID_PREFIX = "__dataset_id"
+  val COL_POS_PREFIX = "__col_position"
+
   def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
     val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
     // Eagerly bind the encoder so we verify that the encoder matches the underlying
@@ -183,6 +189,9 @@ class Dataset[T] private[sql](
     @DeveloperApi @Unstable @transient val encoder: Encoder[T])
   extends Serializable {

+  // A globally unique id for this Dataset.
+  private val id = Dataset.curId.getAndIncrement()
+
   queryExecution.assertAnalyzed()

   // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure
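For intuition, a REPL-style sketch (not part of the patch) of the id scheme these lines introduce: a JVM-wide atomic counter hands out a distinct id to every instance, which is what later lets the analyzer rule tell two branches of a self-join apart. `IdGen` and `Tagged` are illustrative stand-ins for `Dataset.curId` and the new `id` field.

import java.util.concurrent.atomic.AtomicLong

// Stand-in for Dataset.curId: one counter shared by the whole JVM.
object IdGen { val curId = new AtomicLong() }

// Stand-in for Dataset: each instance grabs the next id at construction time.
class Tagged { val id: Long = IdGen.curId.getAndIncrement() }

val a = new Tagged
val b = new Tagged
assert(a.id != b.id)  // distinct instances are distinguishable by id alone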
@@ -873,7 +882,25 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def join(right: Dataset[_]): DataFrame = withPlan {
-    Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE)
+    val (joinLeft, joinRight) = prepareJoinPlan(this, right)
+    Join(joinLeft, joinRight, joinType = Inner, None, JoinHint.NONE)
   }

+  // Called by `Dataset#join`, to attach the Dataset id to the logical plan, so that we
+  // can resolve column references correctly later. See `ResolveDatasetColumnReference`.
+  private def createPlanWithDatasetId(): LogicalPlan = {
+    // The alias should start with `SubqueryAlias.HIDDEN_ALIAS_PREFIX`, so that `SubqueryAlias`
+    // can recognize it and keep the output qualifiers unchanged.
+    SubqueryAlias(s"${SubqueryAlias.HIDDEN_ALIAS_PREFIX}${Dataset.ID_PREFIX}_$id", logicalPlan)
+  }
+
+  private def prepareJoinPlan(left: Dataset[_], right: Dataset[_]): (LogicalPlan, LogicalPlan) = {
+    if (!sparkSession.sessionState.conf.getConf(SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE)) {
+      // If the config is disabled, do nothing.
+      (left.logicalPlan, right.logicalPlan)
+    } else {
+      (left.createPlanWithDatasetId(), right.createPlanWithDatasetId())
+    }
+  }
+
   /**
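For context, a hypothetical reproduction of the self-join problem this hunk targets (names and data are illustrative; assumes a live SparkSession named `spark`):

import spark.implicits._

val df1 = spark.range(3).toDF("id")
val df2 = df1.filter($"id" > 0)

// Pre-patch, df1("id") and df2("id") resolve to the same attribute, so the condition
// can silently degenerate instead of comparing the two join branches. With
// prepareJoinPlan wrapping each side in a dataset-id SubqueryAlias, the new analyzer
// rule can map each column reference back to the branch it came from.
df1.join(df2, df1("id") =!= df2("id")).show()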
@@ -949,10 +976,11 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = {
+    val (joinLeft, joinRight) = prepareJoinPlan(this, right)
     // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
     // by creating a new instance for one of the branches.
     val joined = sparkSession.sessionState.executePlan(
-      Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE))
+      Join(joinLeft, joinRight, joinType = JoinType(joinType), None, JoinHint.NONE))
       .analyzed.asInstanceOf[Join]

     withPlan {
@@ -1014,8 +1042,9 @@ class Dataset[T] private[sql](

     // Trigger analysis so in the case of self-join, the analyzer will clone the plan.
     // After the cloning, left and right side will have distinct expression ids.
+    val (joinLeft, joinRight) = prepareJoinPlan(this, right)
     val plan = withPlan(
-      Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr), JoinHint.NONE))
+      Join(joinLeft, joinRight, JoinType(joinType), Some(joinExprs.expr), JoinHint.NONE))
       .queryExecution.analyzed.asInstanceOf[Join]

     // If auto self join alias is disabled, return the plan.
@@ -1024,9 +1053,7 @@ class Dataset[T] private[sql](
     }

     // If left/right have no output set intersection, return the plan.
-    val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
-    val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
-    if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
+    if (this.logicalPlan.outputSet.intersect(right.logicalPlan.outputSet).isEmpty) {
       return withPlan(plan)
     }
@@ -1289,10 +1316,24 @@ class Dataset[T] private[sql](
       colRegex(colName)
     } else {
       val expr = resolve(colName)
-      Column(expr)
+      Column(addDataFrameIdToCol(expr))
     }
   }

+  // Attach the dataset id and column position to the column reference, so that we can resolve it
+  // correctly in case of self-join. See `ResolveDatasetColumnReference`.
+  private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = expr match {
+    case a: AttributeReference
+        if sparkSession.sessionState.conf.getConf(SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE) =>
+      val metadata = new MetadataBuilder()
+        .withMetadata(a.metadata)
+        .putLong(Dataset.ID_PREFIX, id)
+        .putLong(Dataset.COL_POS_PREFIX, logicalPlan.output.indexWhere(a.semanticEquals))
+        .build()
+      a.withMetadata(metadata)
+    case _ => expr
+  }
+
   /**
    * Selects column based on the column name specified as a regex and returns it as [[Column]].
    * @group untypedrel
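To make the tagging concrete, a small standalone sketch using Spark's public Metadata API; the key strings are written out literally here, while the patch reads them from `Dataset.ID_PREFIX` and `Dataset.COL_POS_PREFIX`:

import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

// Start from the attribute's existing metadata (here: empty), as addDataFrameIdToCol
// does with a.metadata, then layer the two hidden keys on top.
val tagged = new MetadataBuilder()
  .withMetadata(Metadata.empty)
  .putLong("__dataset_id", 42L)    // Dataset.ID_PREFIX -> id of the owning Dataset
  .putLong("__col_position", 0L)   // Dataset.COL_POS_PREFIX -> index in logicalPlan.output
  .build()

assert(tagged.getLong("__dataset_id") == 42L)
assert(tagged.getLong("__col_position") == 0L)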
@@ -2297,11 +2338,16 @@ class Dataset[T] private[sql](
         u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
       case Column(expr: Expression) => expr
     }
-    val attrs = this.logicalPlan.output
-    val colsAfterDrop = attrs.filter { attr =>
-      attr != expression
-    }.map(attr => Column(attr))
-    select(colsAfterDrop : _*)
+    expression match {
+      case a: Attribute =>
+        val attrs = this.logicalPlan.output
+        val colsAfterDrop = attrs.filter { attr =>
+          !attr.semanticEquals(a)
+        }.map(attr => Column(attr))
+        select(colsAfterDrop : _*)
+
+      case _ => toDF()
+    }
   }

   /**
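A behavioral sketch of the reworked `drop` (column name and session are assumed): `semanticEquals` matches an attribute even when cosmetic metadata, such as the new dataset-id tags, differs, and non-attribute expressions now explicitly fall through to a no-op.

import org.apache.spark.sql.functions.lit

val df = spark.range(3).withColumn("flag", lit(1))

df.drop(df("flag"))  // an Attribute: matched via semanticEquals, so "flag" is removed
df.drop(lit(1))      // not an Attribute: hits the new `case _` and returns the DataFrame unchanged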
sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/ResolveDatasetColumnReference.scala (new file)

@@ -0,0 +1,162 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.analysis

import scala.collection.mutable
import scala.util.Try

import org.apache.spark.sql.{AnalysisException, Column, Dataset}
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Equality, EqualNullSafe, EqualTo}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

/**
 * Resolves Dataset column references by traversing the query plan and finding the plan subtree
 * of the Dataset that the column reference belongs to.
 *
 * A Dataset column reference is simply an [[AttributeReference]] that is returned by
 * `Dataset#col`. Most of the time we don't need to do anything special, as an
 * [[AttributeReference]] can point to the column precisely. However, in case of self-join, the
 * analyzer generates [[AttributeReference]]s with new expr IDs for the right side plan of the
 * join. If the Dataset column reference points to a column in the right side plan of a
 * self-join, we need to replace it with the corresponding newly generated
 * [[AttributeReference]].
 */
class ResolveDatasetColumnReference(conf: SQLConf) extends Rule[LogicalPlan] {

  // A Dataset column reference is an `AttributeReference` with 2 special metadata entries.
  private def isColumnReference(a: AttributeReference): Boolean = {
    a.metadata.contains(Dataset.ID_PREFIX) && a.metadata.contains(Dataset.COL_POS_PREFIX)
  }

  private case class ColumnReference(datasetId: Long, colPos: Int)

  private def toColumnReference(a: AttributeReference): ColumnReference = {
    ColumnReference(
      a.metadata.getLong(Dataset.ID_PREFIX),
      a.metadata.getLong(Dataset.COL_POS_PREFIX).toInt)
  }

  override def apply(plan: LogicalPlan): LogicalPlan = {
    if (!conf.getConf(SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE)) return plan

    // We always remove the special metadata from `AttributeReference` at the end of this rule,
    // so Dataset column references only exist in the root node, via Dataset transformations
    // like `Dataset#select`.
    val colRefAttrs = plan.expressions.flatMap(_.collect {
      case a: AttributeReference if isColumnReference(a) => a
    })

    if (colRefAttrs.isEmpty) {
      plan
    } else {
      val colRefs = colRefAttrs.map(toColumnReference).distinct
      // Keeps the mapping between a column reference and the actual attribute it points to.
      // This will be used to replace the column references with actual attributes later.
      val colRefToActualAttr = new mutable.HashMap[ColumnReference, AttributeReference]()
      // Keeps the column references that point to more than one actual attribute.
      val ambiguousColRefs = new mutable.HashMap[ColumnReference, Seq[AttributeReference]]()
      // We only care about the `SubqueryAlias`es referring to the Datasets that produced the
      // column references we want to resolve here.
      val dsIdSet = colRefs.map(_.datasetId).toSet
      // If a column reference points to an attribute that is not present in the plan's
      // inputSet, we should ignore it as it's invalid.
      val inputSet = plan.inputSet

      plan.foreach {
        // We only add the special `SubqueryAlias` to attach the dataset id for self-join. After
        // self-join resolution, the child of `SubqueryAlias` should have generated new
        // `AttributeReference`s, and we need to resolve column references against them.
        case SubqueryAlias(DatasetIdAlias(id), child) if dsIdSet.contains(id) =>
          colRefs.foreach { ref =>
            if (id == ref.datasetId) {
              if (ref.colPos < 0 || ref.colPos >= child.output.length) {
                throw new IllegalStateException("[BUG] Hit an invalid Dataset column " +
                  s"reference: $ref. Please open a JIRA ticket to report it.")
              } else {
                val actualAttr = child.output(ref.colPos).asInstanceOf[AttributeReference]
                if (inputSet.contains(actualAttr)) {
                  // Record the ambiguous column references. We will deal with them later.
                  if (ambiguousColRefs.contains(ref)) {
                    assert(!colRefToActualAttr.contains(ref))
                    ambiguousColRefs(ref) = ambiguousColRefs(ref) :+ actualAttr
                  } else if (colRefToActualAttr.contains(ref)) {
                    ambiguousColRefs(ref) = Seq(colRefToActualAttr.remove(ref).get, actualAttr)
                  } else {
                    colRefToActualAttr(ref) = actualAttr
                  }
                }
              }
            }
          }

        case _ =>
      }

      val deAmbiguousColRefs = new mutable.HashSet[ColumnReference]()
      val newPlan = plan.transformExpressions {
        case e @ Equality(a: AttributeReference, b: AttributeReference)
            if isColumnReference(a) && isColumnReference(b) && a.sameRef(b) =>
          val colRefA = toColumnReference(a)
          val colRefB = toColumnReference(b)
          val maybeActualAttrs = ambiguousColRefs.get(colRefA)
          if (colRefA == colRefB && maybeActualAttrs.exists(_.length == 2)) {
            deAmbiguousColRefs += colRefA
            if (e.isInstanceOf[EqualTo]) {
              EqualTo(maybeActualAttrs.get.head, maybeActualAttrs.get.last)
            } else {
              EqualNullSafe(maybeActualAttrs.get.head, maybeActualAttrs.get.last)
            }
          } else {
            e
          }

        case a: AttributeReference if isColumnReference(a) =>
          val actualAttr = colRefToActualAttr.getOrElse(toColumnReference(a), a)
          // Remove the special metadata from this `AttributeReference`, as column reference
          // resolution is done.
          Column.stripColumnReferenceMetadata(actualAttr)
      }

      ambiguousColRefs.filterKeys(!deAmbiguousColRefs.contains(_)).foreach { case (ref, _) =>
        val originalAttr = colRefAttrs.find(attr => toColumnReference(attr) == ref).get
        throw new AnalysisException(s"Column $originalAttr is ambiguous. It's probably " +
          "because you joined several Datasets together, and some of these Datasets are the " +
          "same. This column points to one of the Datasets but Spark is unable to figure out " +
          "which Dataset. Please alias the Datasets with different names via `Dataset.as` " +
          "before joining them, and specify the column using a qualified name, e.g. " +
          """`df.as("a").join(df.as("b"), $"a.id" > $"b.id")`.""")
      }

      newPlan
    }
  }
}

object DatasetIdAlias {
  def unapply(alias: AliasIdentifier): Option[Long] = {
    val expectedPrefix = SubqueryAlias.HIDDEN_ALIAS_PREFIX + Dataset.ID_PREFIX
    if (alias.database.isEmpty && alias.identifier.startsWith(expectedPrefix)) {
      Try(alias.identifier.drop(expectedPrefix.length + 1).toLong).toOption
    } else {
      None
    }
  }
}