@@ -345,7 +345,8 @@ class Analyzer(
gid: Expression): Expression = {
expr transform {
case e: GroupingID =>
if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
if (e.groupByExprs.isEmpty ||
e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) {
Alias(gid, toPrettySQL(e))()
} else {
throw new AnalysisException(
@@ -952,6 +953,8 @@ class Analyzer(
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
// Intersect/Except will be rewritten to Join at the beginning of the optimizer. Here we need to
// deduplicate the right side plan, so that we won't produce an invalid self-join later.
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
i.copy(right = dedupRight(left, right))
case e @ Except(left, right, _) if !e.duplicateResolved =>
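A note on the canonicalized comparison above: since `Dataset#col` now attaches extra metadata to the attribute it returns, plain `==` on the grouping expressions can fail even when they refer to the same column. A minimal sketch of the difference (names and types are illustrative, not from this PR):

  import org.apache.spark.sql.catalyst.expressions.AttributeReference
  import org.apache.spark.sql.types.LongType

  // Two references to the same underlying column, differing only cosmetically.
  val a = AttributeReference("id", LongType)()
  val b = a.withQualifier(Seq("t"))

  assert(a != b)                              // plain equality sees the qualifier
  assert(a.canonicalized == b.canonicalized)  // canonicalization ignores such details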
@@ -951,13 +951,24 @@ case class SubqueryAlias(
def alias: String = name.identifier

override def output: Seq[Attribute] = {
val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias))
child.output.map(_.withQualifier(qualifierList))
if (isHiddenAlias) {
child.output
} else {
val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias))
child.output.map(_.withQualifier(qualifierList))
}
}

override def doCanonicalize(): LogicalPlan = child.canonicalized

def isHiddenAlias: Boolean = {
name.database.isEmpty && name.identifier.startsWith(SubqueryAlias.HIDDEN_ALIAS_PREFIX)
}
}

object SubqueryAlias {
val HIDDEN_ALIAS_PREFIX = "__hidden_alias"

def apply(
identifier: String,
child: LogicalPlan): SubqueryAlias = {
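To illustrate the new behavior (a hedged sketch; the one-column LocalRelation is made up for the example): a regular alias re-qualifies the child's output, while a hidden alias leaves the output, and thus any qualifiers, untouched.

  import org.apache.spark.sql.catalyst.expressions.AttributeReference
  import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, SubqueryAlias}
  import org.apache.spark.sql.types.LongType

  val child = LocalRelation(AttributeReference("id", LongType)())

  val visible = SubqueryAlias("t", child)
  assert(visible.output.head.qualifier == Seq("t"))  // output re-qualified with the alias

  val hidden = SubqueryAlias(SubqueryAlias.HIDDEN_ALIAS_PREFIX + "__dataset_id_0", child)
  assert(hidden.isHiddenAlias)
  assert(hidden.output == child.output)              // output passes through unchanged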
@@ -759,6 +759,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val RESOLVE_DATASET_COLUMN_REFERENCE =
buildConf("spark.sql.analyzer.resolveDatasetColumnReference")
.doc("When true, resolve Dataset column reference in case of self-join.")
.internal()
.booleanConf
.createWithDefault(true)

// Whether to retain group by columns or not in GroupedData.agg.
val DATAFRAME_RETAIN_GROUP_COLUMNS = buildConf("spark.sql.retainGroupColumns")
.internal()
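Since the new resolution logic changes user-visible behavior, the internal flag above doubles as an escape hatch, for example:

  // Fall back to the old column-resolution behavior for a session.
  spark.conf.set("spark.sql.analyzer.resolveDatasetColumnReference", "false")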
18 changes: 16 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -48,6 +48,15 @@ private[sql] object Column {
case expr => toPrettySQL(expr)
}
}

private[sql] def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = {
val metadataWithoutId = new MetadataBuilder()
.withMetadata(a.metadata)
.remove(Dataset.ID_PREFIX)
.remove(Dataset.COL_POS_PREFIX)
.build()
a.withMetadata(metadataWithoutId)
}
}

/**
@@ -144,11 +153,16 @@ class Column(val expr: Expression) extends Logging {
override def toString: String = toPrettySQL(expr)

override def equals(that: Any): Boolean = that match {
case that: Column => that.expr.equals(this.expr)
case that: Column => that.normalizedExpr().equals(this.normalizedExpr())
case _ => false
}

override def hashCode: Int = this.expr.hashCode()
override def hashCode: Int = this.normalizedExpr().hashCode()

private def normalizedExpr(): Expression = expr match {
case a: AttributeReference => Column.stripColumnReferenceMetadata(a)
case _ => expr
}

/** Creates a column based on the given expression. */
private def withExpr(newExpr: Expression): Column = new Column(newExpr)
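A quick sketch of why `equals`/`hashCode` now normalize: a reference tagged by `Dataset#col` should still compare equal to its untagged counterpart. The key literals mirror `Dataset.ID_PREFIX` and `Dataset.COL_POS_PREFIX`; the values are made up:

  import org.apache.spark.sql.Column
  import org.apache.spark.sql.catalyst.expressions.AttributeReference
  import org.apache.spark.sql.types.{LongType, MetadataBuilder}

  val attr = AttributeReference("id", LongType)()
  val tagged = attr.withMetadata(
    new MetadataBuilder()
      .putLong("__dataset_id", 0L)    // Dataset.ID_PREFIX
      .putLong("__col_position", 0L)  // Dataset.COL_POS_PREFIX
      .build())

  // The hidden metadata is stripped before comparing, so these are still equal.
  assert(new Column(attr) == new Column(tagged))
  assert(new Column(attr).hashCode == new Column(tagged).hashCode)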
70 changes: 58 additions & 12 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable}
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
@@ -62,6 +63,11 @@ import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils

private[sql] object Dataset {
val curId = new java.util.concurrent.atomic.AtomicLong()

val ID_PREFIX = "__dataset_id"
val COL_POS_PREFIX = "__col_position"

def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
// Eagerly bind the encoder so we verify that the encoder matches the underlying
@@ -183,6 +189,9 @@ class Dataset[T] private[sql](
@DeveloperApi @Unstable @transient val encoder: Encoder[T])
extends Serializable {

// A globally unique id for this Dataset.
private val id = Dataset.curId.getAndIncrement()

queryExecution.assertAnalyzed()

// Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure
@@ -873,7 +882,25 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def join(right: Dataset[_]): DataFrame = withPlan {
Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE)
val (joinLeft, joinRight) = prepareJoinPlan(this, right)
Join(joinLeft, joinRight, joinType = Inner, None, JoinHint.NONE)
}

// Called by `Dataset#join`, to attach the Dataset id to the logical plan, so that we
// can resolve column reference correctly later. See `ResolveDatasetColumnReference`.
private def createPlanWithDatasetId(): LogicalPlan = {
// The alias should start with `SubqueryAlias.HIDDEN_ALIAS_PREFIX`, so that `SubqueryAlias` can
(review thread)
Contributor: I am not a big fan of using special strings as markers; I'd rather introduce a new attribute to carry the information.

Contributor Author: It's at the plan level, so we can't use attributes to carry the information here.

Contributor: I meant adding a new attribute to the SubqueryAlias case class. Why isn't that possible?

Contributor Author: It's possible, but I'm a little worried about changing the constructor of a widely used plan node like SubqueryAlias. It's also a little weird to put the Dataset concept into the catalyst module.
Another idea I've thought about is adding a new no-op plan node to carry the dataset id. I gave it up because no-op nodes can be troublesome, going by past experience with ResolvedHint and AnalysisBarrier.

Contributor:
> Another idea I've thought about is adding a new no-op plan node to carry the dataset id
Yes, that's exactly the reason why I'd like to avoid adding this node. I think I'd actually prefer a no-op node if this is really needed, since SubqueryAlias is also used in other places and we may introduce side effects by using it differently from what it is intended for now.

Contributor: @cloud-fan I was thinking: in order to avoid adding another "placeholder" plan, and since here we are dealing only with joins, what about adding a leftDatasetId and rightDatasetId to the join operator?

Contributor Author: That works, but I'd rather hack into SubqueryAlias instead of Join, as the Join node is more widely used...

Contributor Author: ...and things can get tricky if join reordering happens.

Contributor:
> and things can get tricky if join reorder happens
Yes, I also thought of this, but I'm still not very confident about using SubqueryAlias: that node is meant for a different purpose, and we already have things used in hacky ways, which makes any future change more error prone, because we need to think not only about the logic but also remember all the "hacks" done with it. Honestly, rather than using SubqueryAlias, if no other option works, I'd prefer adding a new node (maybe a subclass of it, if we need common functionality).

Member: IMO, SubqueryAlias might be safer than a new node, because it is well known to the Spark SQL community. A new node that others aren't aware of is more dangerous.
(end review thread)

// recognize it and keep the output qualifiers unchanged.
SubqueryAlias(s"${SubqueryAlias.HIDDEN_ALIAS_PREFIX}${Dataset.ID_PREFIX}_$id", logicalPlan)
(review thread)
Contributor: Here I am missing something. Sorry if the question is dumb, but I can't see why we need to add this "special" node here.

Contributor Author: To attach the dataset id to the logical plan, so that we can resolve column references later.

Contributor: I'd rather add the datasetId info to LogicalPlan itself and avoid introducing a new plan here. I think it would be easier this way to generalize the approach to other cases where the same problem may arise. What do you think?

Contributor Author: How can we add a dataset id to every kind of LogicalPlan?

Contributor: I think that could be an option. And for the moment we could add it only to the child/children of a join, since we only need it there. But I see there is no guarantee that the plan(s) are not replaced/removed during the analysis/optimization phase, so it may not be doable indeed.
(end review thread)

}

private def prepareJoinPlan(left: Dataset[_], right: Dataset[_]): (LogicalPlan, LogicalPlan) = {
if (!sparkSession.sessionState.conf.getConf(SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE)) {
// If the config is disabled, do nothing.
(left.logicalPlan, right.logicalPlan)
} else {
(left.createPlanWithDatasetId(), right.createPlanWithDatasetId())
}
}

/**
@@ -949,10 +976,11 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = {
val (joinLeft, joinRight) = prepareJoinPlan(this, right)
// Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
// by creating a new instance for one of the branch.
val joined = sparkSession.sessionState.executePlan(
Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE))
Join(joinLeft, joinRight, joinType = JoinType(joinType), None, JoinHint.NONE))
.analyzed.asInstanceOf[Join]

withPlan {
@@ -1014,8 +1042,9 @@ class Dataset[T] private[sql](

// Trigger analysis so in the case of self-join, the analyzer will clone the plan.
// After the cloning, left and right side will have distinct expression ids.
val (joinLeft, joinRight) = prepareJoinPlan(this, right)
val plan = withPlan(
Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr), JoinHint.NONE))
Join(joinLeft, joinRight, JoinType(joinType), Some(joinExprs.expr), JoinHint.NONE))
.queryExecution.analyzed.asInstanceOf[Join]

// If auto self join alias is disabled, return the plan.
@@ -1024,9 +1053,7 @@
}

// If left/right have no output set intersection, return the plan.
val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
if (this.logicalPlan.outputSet.intersect(right.logicalPlan.outputSet).isEmpty) {
return withPlan(plan)
}

@@ -1289,10 +1316,24 @@ class Dataset[T] private[sql](
colRegex(colName)
} else {
val expr = resolve(colName)
Column(expr)
Column(addDataFrameIdToCol(expr))
(review thread)
Member (@viirya, Apr 27, 2019): As indicated by the change in Analyzer, two attributes that we considered the same before are now different if we compare them without canonicalization.
Not sure if it will be an issue, but it is counterintuitive.

Contributor Author: Internally, I think we should never compare two attributes with equals rather than semanticEquals.
Externally, if users rely on the equality of Column, then yes, this PR may break it in some cases. If this matters, I can add a migration guide entry or update Column.equals.

Member: I feel that this metadata is not widely known. It might be confusing to some people that the comparison fails when they know it worked before. An update to the migration guide or to Column.equals both sound good.
(end review thread)

}
}

// Attach the dataset id and column position to the column reference, so that we can resolve it
// correctly in case of self-join. See `ResolveDatasetColumnReference`.
private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = expr match {
case a: AttributeReference
if sparkSession.sessionState.conf.getConf(SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE) =>
val metadata = new MetadataBuilder()
.withMetadata(a.metadata)
.putLong(Dataset.ID_PREFIX, id)
.putLong(Dataset.COL_POS_PREFIX, logicalPlan.output.indexWhere(a.semanticEquals))
.build()
a.withMetadata(metadata)
case _ => expr
}
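For illustration, with the config enabled every column returned by `Dataset#col` now carries the two entries (a hedged check; the key literals mirror the constants above):

  import org.apache.spark.sql.catalyst.expressions.AttributeReference

  val df = spark.range(1)
  val meta = df("id").expr.asInstanceOf[AttributeReference].metadata
  assert(meta.contains("__dataset_id"))    // Dataset.ID_PREFIX
  assert(meta.contains("__col_position"))  // Dataset.COL_POS_PREFIX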

/**
* Selects column based on the column name specified as a regex and returns it as [[Column]].
* @group untypedrel
@@ -2297,11 +2338,16 @@ class Dataset[T] private[sql](
u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
case Column(expr: Expression) => expr
}
val attrs = this.logicalPlan.output
val colsAfterDrop = attrs.filter { attr =>
attr != expression
(review thread)
Contributor Author: drop should not match attributes by equality. We should use semanticEquals.
(end review thread)

}.map(attr => Column(attr))
select(colsAfterDrop : _*)
expression match {
case a: Attribute =>
val attrs = this.logicalPlan.output
val colsAfterDrop = attrs.filter { attr =>
!attr.semanticEquals(a)
}.map(attr => Column(attr))
select(colsAfterDrop : _*)

case _ => toDF()
}
}

/**
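Taken together, a hedged end-to-end sketch of the self-join this PR targets (the frame names are illustrative). Previously `df1("id")` and `df2("id")` carried the same expression ID, so the condition below compared a column with itself; with the id/position metadata, the new rule can re-point `df2("id")` at the regenerated attributes on the right side:

  import spark.implicits._

  val df1 = spark.range(3)
  val df2 = df1.filter($"id" > 0)

  // Both column references are tagged with (dataset id, column position), so
  // ResolveDatasetColumnReference can tell which plan each of them belongs to.
  df1.join(df2, df1("id") > df2("id")).show()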
@@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.analysis

import scala.collection.mutable
import scala.util.Try

import org.apache.spark.sql.{AnalysisException, Column, Dataset}
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Equality, EqualNullSafe, EqualTo}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

/**
* Resolves the Dataset column reference by traversing the query plan and finding the plan subtree
* of the Dataset that the column reference belongs to.
*
* Dataset column reference is simply an [[AttributeReference]] that is returned by `Dataset#col`.
* Most of the time we don't need to do anything special, as an [[AttributeReference]] can point to
* the column precisely. However, in case of self-join, the analyzer generates
* [[AttributeReference]] with new expr IDs for the right side plan of the join. If the Dataset
* column reference points to a column in the right side plan of a self-join, we need to replace it
* with the corresponding newly generated [[AttributeReference]].
*/
class ResolveDatasetColumnReference(conf: SQLConf) extends Rule[LogicalPlan] {

// A Dataset column reference is an `AttributeReference` with 2 special metadata entries.
private def isColumnReference(a: AttributeReference): Boolean = {
a.metadata.contains(Dataset.ID_PREFIX) && a.metadata.contains(Dataset.COL_POS_PREFIX)
}

private case class ColumnReference(datasetId: Long, colPos: Int)

private def toColumnReference(a: AttributeReference): ColumnReference = {
ColumnReference(
a.metadata.getLong(Dataset.ID_PREFIX),
a.metadata.getLong(Dataset.COL_POS_PREFIX).toInt)
}

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.getConf(SQLConf.RESOLVE_DATASET_COLUMN_REFERENCE)) return plan

// We always remove the special metadata from `AttributeReference` at the end of this rule, so
// Dataset column references can only exist in the root node, added via Dataset transformations
// like `Dataset#select`.
val colRefAttrs = plan.expressions.flatMap(_.collect {
case a: AttributeReference if isColumnReference(a) => a
})

if (colRefAttrs.isEmpty) {
plan
} else {
val colRefs = colRefAttrs.map(toColumnReference).distinct
// Keeps the mapping between the column reference and the actual attribute it points to. This
// will be used to replace the column references with actual attributes later.
val colRefToActualAttr = new mutable.HashMap[ColumnReference, AttributeReference]()
// Keeps the column references that point to more than one actual attribute.
val ambiguousColRefs = new mutable.HashMap[ColumnReference, Seq[AttributeReference]]()
// We only care about `SubqueryAlias` referring to Datasets which produces the column
// references that we want to resolve here.
val dsIdSet = colRefs.map(_.datasetId).toSet
// If a column reference points to an attribute that is not present in the plan's inputSet, we
// should ignore it as it's invalid.
val inputSet = plan.inputSet

plan.foreach {
// We only add the special `SubqueryAlias` to attach the dataset id for self-joins. After
// self-join resolution, the child of `SubqueryAlias` will have newly generated
// `AttributeReference`s, and we need to resolve column references against them.
case SubqueryAlias(DatasetIdAlias(id), child) if dsIdSet.contains(id) =>
colRefs.foreach { case ref =>
if (id == ref.datasetId) {
if (ref.colPos < 0 || ref.colPos >= child.output.length) {
throw new IllegalStateException("[BUG] Hit an invalid Dataset column reference: " +
s"$ref. Please open a JIRA ticket to report it.")
} else {
val actualAttr = child.output(ref.colPos).asInstanceOf[AttributeReference]
if (inputSet.contains(actualAttr)) {
// Record the ambiguous column references. We will deal with them later.
if (ambiguousColRefs.contains(ref)) {
assert(!colRefToActualAttr.contains(ref))
ambiguousColRefs(ref) = ambiguousColRefs(ref) :+ actualAttr
} else if (colRefToActualAttr.contains(ref)) {
ambiguousColRefs(ref) = Seq(colRefToActualAttr.remove(ref).get, actualAttr)
} else {
colRefToActualAttr(ref) = actualAttr
}
}

}
}
}

case _ =>
}

val deAmbiguousColsRefs = new mutable.HashSet[ColumnReference]()
val newPlan = plan.transformExpressions {
case e @ Equality(a: AttributeReference, b: AttributeReference)
if isColumnReference(a) && isColumnReference(b) && a.sameRef(b) =>
val colRefA = toColumnReference(a)
val colRefB = toColumnReference(b)
val maybeActualAttrs = ambiguousColRefs.get(colRefA)
if (colRefA == colRefB && maybeActualAttrs.exists(_.length == 2)) {
deAmbiguousColsRefs += colRefA
if (e.isInstanceOf[EqualTo]) {
EqualTo(maybeActualAttrs.get.head, maybeActualAttrs.get.last)
} else {
EqualNullSafe(maybeActualAttrs.get.head, maybeActualAttrs.get.last)
}
} else {
e
}

case a: AttributeReference if isColumnReference(a) =>
val actualAttr = colRefToActualAttr.getOrElse(toColumnReference(a), a)
// Remove the special metadata from this `AttributeReference`, as the column reference
// resolving is done.
Column.stripColumnReferenceMetadata(actualAttr)
}

ambiguousColRefs.filterKeys(!deAmbiguousColsRefs.contains(_)).foreach { case (ref, _) =>
val originalAttr = colRefAttrs.find(attr => toColumnReference(attr) == ref).get
throw new AnalysisException(s"Column $originalAttr is ambiguous. It's probably " +
"because you joined several Datasets together, and some of these Datasets are the " +
"same. This column points to one of the Datasets but Spark is unable to figure out " +
"which Datasset. Please alias the Datasets with different names via `Dataset.as` " +
"before joining them, and specify the column using qualified name, e.g. " +
"""`df.as("a").join(df.as("b"), $"a.id" > $"b.id")`.""")
}

newPlan
}
}

object DatasetIdAlias {
def unapply(alias: AliasIdentifier): Option[Long] = {
val expectedPrefix = SubqueryAlias.HIDDEN_ALIAS_PREFIX + Dataset.ID_PREFIX
if (alias.database.isEmpty && alias.identifier.startsWith(expectedPrefix)) {
Try(alias.identifier.drop(expectedPrefix.length + 1).toLong).toOption
} else {
None
}
}
}
}
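The `Equality` special case above exists for join conditions that compare a tagged column reference with itself, as happens when a Dataset is joined directly with itself. A hedged sketch of the case it rewrites:

  val df = spark.range(3)

  // df("id") alone is ambiguous here, since df appears on both sides. The rule
  // rewrites the equality to compare the left plan's attribute with the right
  // plan's instead of raising the ambiguity error.
  df.join(df, df("id") === df("id")).show()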