@@ -134,6 +134,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
expr: Expression,
resolveColumnByName: Seq[String] => Option[Expression],
getAttrCandidates: () => Seq[Attribute],
resolveOnDatasetId: (Long, String) => Option[NamedExpression],
throws: Boolean,
includeLastResort: Boolean): Expression = {
def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) {
@@ -156,6 +157,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
}
matched(ordinal)

case UnresolvedAttributeWithTag(attr, id) =>
  resolveOnDatasetId(id, attr.name).getOrElse(attr)

case u @ UnresolvedAttribute(nameParts) =>
val result = withPosition(u) {
resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map {
@@ -452,6 +456,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
plan.resolve(nameParts, conf.resolver)
},
getAttrCandidates = () => plan.output,
resolveOnDatasetId = (_, _) => None,
throws = throws,
includeLastResort = includeLastResort)
}
@@ -477,6 +482,57 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
assert(q.children.length == 1)
q.children.head.output
},

resolveOnDatasetId = (datasetId: Long, name: String) => {
  // Walk down the chain of single-child nodes starting at `lp`, returning the
  // first node whose DATASET_ID_TAG contains `datasetId`, along with its depth.
  def findUnaryNodeMatchingTagId(lp: LogicalPlan): Option[(LogicalPlan, Int)] = {
    var currentLp = lp
    var depth = 0
    while (true) {
      if (currentLp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(datasetId))) {
        return Some((currentLp, depth))
      } else if (currentLp.children.size == 1) {
        currentLp = currentLp.children.head
        depth += 1
      } else {
        // Leaf node, or a node with more than one child: stop the walk.
        return None
      }
    }
    None // unreachable, but needed to satisfy the declared return type
  }

val binaryNodeOpt = q.collectFirst {
case bn: BinaryNode => bn
}

val resolveOnAttribs = binaryNodeOpt match {
  case Some(bn) =>
    val leftDefOpt = findUnaryNodeMatchingTagId(bn.left)
    val rightDefOpt = findUnaryNodeMatchingTagId(bn.right)
    (leftDefOpt, rightDefOpt) match {
      case (None, Some((lp, _))) => lp.output
      case (Some((lp, _)), None) => lp.output
      case (Some((lp1, depth1)), Some((lp2, depth2))) =>
        if (depth1 == depth2) {
          // Matches at the same depth on both sides are genuinely ambiguous:
          // fall back to resolving against the full child output.
          q.children.head.output
        } else if (depth1 < depth2) {
          // Prefer the match closer to the binary node.
          lp1.output
        } else {
          lp2.output
        }
      case (None, None) => q.children.head.output
    }

  case _ => q.children.head.output
}
AttributeSeq.fromNormalOutput(resolveOnAttribs).resolve(Seq(name), conf.resolver)
},
throws = true,
includeLastResort = includeLastResort)
}
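
To make the new resolution path concrete, here is an illustrative sketch of the
self-join shape it targets. The DataFrames and column names are invented and
`spark` is assumed to be a live SparkSession; this snippet is not part of the
patch.

import org.apache.spark.sql.functions.col

// df2 derives from df1, so joining them is a self-join: after the analyzer
// deduplicates the right side, df1("b") keeps its old expression id and
// name-based lookup alone cannot tell the two sides apart. The __dataset_id
// metadata on the column records which Dataset it came from, and
// resolveOnDatasetId re-binds it against that side's output.
val df1 = spark.range(10).toDF("a").withColumn("b", col("a") + 1)
val df2 = df1.filter(col("a") > 3)
val joined = df1.join(df2, df1("a") === df2("a"))
joined.select(df1("b")).show()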
@@ -268,6 +268,47 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
}
}

case class UnresolvedAttributeWithTag(attribute: Attribute, datasetId: Long)
  extends Attribute with Unevaluable {
def name: String = attribute.name

override def exprId: ExprId = throw new UnresolvedException("exprId")

override def dataType: DataType = throw new UnresolvedException("dataType")

override def nullable: Boolean = throw new UnresolvedException("nullable")

override def qualifier: Seq[String] = throw new UnresolvedException("qualifier")

override lazy val resolved = false

override def newInstance(): UnresolvedAttributeWithTag = this

override def withNullability(newNullability: Boolean): UnresolvedAttributeWithTag = this

override def withQualifier(newQualifier: Seq[String]): UnresolvedAttributeWithTag = this

override def withName(newName: String): UnresolvedAttributeWithTag = this

override def withMetadata(newMetadata: Metadata): Attribute = this

override def withExprId(newExprId: ExprId): UnresolvedAttributeWithTag = this

override def withDataType(newType: DataType): Attribute = this

final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE)

override def toString: String = s"'$name"

override def sql: String = attribute.sql

/**
 * Returns true if the wrapped attribute's name matches the given token in a
 * case-insensitive way.
 */
def equalsIgnoreCase(token: String): Boolean = token.equalsIgnoreCase(attribute.name)
}
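
A minimal usage sketch of the placeholder; the attribute below is a hand-built
stand-in, whereas in the patch the wrapped attribute always comes from a
Dataset column carrying __dataset_id metadata.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.LongType

val attr = AttributeReference("a", LongType)() // hypothetical source attribute
val tagged = UnresolvedAttributeWithTag(attr, datasetId = 42L)
assert(!tagged.resolved)       // stays unresolved until re-bound by dataset id
assert(tagged.sql == attr.sql) // renders like the attribute it wraps

Because every `with*` method returns `this`, the wrapper survives plan
transformations unchanged until the analyzer replaces it.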

object UnresolvedAttribute extends AttributeNameParser {
/**
* Creates an [[UnresolvedAttribute]], parsing segments separated by dots ('.').
Expand Down
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans.logical

import scala.collection.mutable

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
@@ -30,7 +32,6 @@ import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{MapType, StructType}


abstract class LogicalPlan
extends QueryPlan[LogicalPlan]
with AnalysisHelper
@@ -199,6 +200,7 @@ object LogicalPlan {
// to the old code path.
private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
private[spark] val IS_METADATA_COL = TreeNodeTag[Unit]("is_metadata_col")
private[spark] val DATASET_ID_TAG = TreeNodeTag[mutable.HashSet[Long]]("dataset_id")
}
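
For illustration, a hedged sketch of how the tag might be written and read; the
`plan` value is assumed to be some LogicalPlan, and the code that actually
populates the tag lives on the Dataset side of this patch.

import scala.collection.mutable

// Record that `plan` materializes the Datasets with ids 1 and 2, then query it.
val ids = plan.getTagValue(LogicalPlan.DATASET_ID_TAG)
  .getOrElse(new mutable.HashSet[Long])
ids += 1L
ids += 2L
plan.setTagValue(LogicalPlan.DATASET_ID_TAG, ids)
assert(plan.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(1L)))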

92 changes: 86 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -38,6 +38,7 @@ import org.apache.spark.api.r.RRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.Dataset.DATASET_ID_KEY
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
@@ -47,7 +48,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
@@ -73,7 +74,7 @@ private[sql] object Dataset {
val curId = new java.util.concurrent.atomic.AtomicLong()
val DATASET_ID_KEY = "__dataset_id"
val COL_POS_KEY = "__col_position"
-val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id")
+val DATASET_ID_TAG = LogicalPlan.DATASET_ID_TAG

def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -1150,10 +1151,10 @@ class Dataset[T] private[sql](

// Trigger analysis so that, in the case of a self-join, the analyzer clones the
// plan and the left and right sides end up with distinct expression ids.
val plan = withPlan(
-  Join(logicalPlan, right.logicalPlan,
-    JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE))
-  .queryExecution.analyzed.asInstanceOf[Join]
+  tryAmbiguityResolution(right, joinExprs, joinType)
+).queryExecution.analyzed.asInstanceOf[Join]

// If auto self join alias is disabled, return the plan.
if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
@@ -1174,6 +1175,32 @@
JoinWith.resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, plan)
}

private def tryAmbiguityResolution(
    right: Dataset[_],
    joinExprs: Option[Column],
    joinType: String): Join = {
  // First analyze the join without its condition so that, for a self-join, the
  // analyzer deduplicates the right side and assigns distinct expression ids.
  val planPart1 = withPlan(
    Join(logicalPlan, right.logicalPlan,
      JoinType(joinType), None, JoinHint.NONE))
    .queryExecution.analyzed.asInstanceOf[Join]
  val inputSet = planPart1.outputSet
  // Then rectify the condition: any attribute whose recorded dataset id no
  // longer matches the side it would resolve against is replaced by an
  // UnresolvedAttributeWithTag, to be re-resolved by dataset id during analysis.
  val joinExprsRectified = joinExprs.map(_.expr transformUp {
    case attr: AttributeReference if attr.metadata.contains(Dataset.DATASET_ID_KEY) =>
      val attribTagId = attr.metadata.getLong(Dataset.DATASET_ID_KEY)
      val leftTagIds = planPart1.left.getTagValue(LogicalPlan.DATASET_ID_TAG)
      val rightTagIds = planPart1.right.getTagValue(LogicalPlan.DATASET_ID_TAG)
      if (!inputSet.contains(attr) ||
        (planPart1.left.outputSet.contains(attr) &&
          !leftTagIds.exists(_.contains(attribTagId))) ||
        (planPart1.right.outputSet.contains(attr) &&
          !rightTagIds.exists(_.contains(attribTagId)))) {
        UnresolvedAttributeWithTag(attr, attribTagId)
      } else {
        attr
      }
  })

  Join(planPart1.left, planPart1.right, JoinType(joinType), joinExprsRectified, JoinHint.NONE)
}
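
As a hedged illustration of what the rewrite keys on: a column obtained through
Dataset.col carries the originating dataset id in its metadata (assuming the
default configuration that attaches __dataset_id for self-join disambiguation;
`spark` is assumed in scope).

import org.apache.spark.sql.catalyst.expressions.AttributeReference

val df = spark.range(3).toDF("a")
val attr = df("a").expr.asInstanceOf[AttributeReference]
// This id lets tryAmbiguityResolution decide whether the attribute still
// belongs to the side it would resolve against after deduplication.
val originId = attr.metadata.getLong(Dataset.DATASET_ID_KEY)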

/**
* Join with another `DataFrame`, using the given join expression. The following performs
* a full outer join between `df1` and `df2`.
@@ -1308,12 +1335,20 @@ class Dataset[T] private[sql](
case a: AttributeReference if logicalPlan.outputSet.contains(a) =>
val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.left.output(index)

case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
UnresolvedAttributeWithTag(a, a.metadata.getLong(Dataset.DATASET_ID_KEY))
}

val rightAsOfExpr = rightAsOf.expr.transformUp {
case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.right.output(index)

case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
UnresolvedAttributeWithTag(a, a.metadata.getLong(Dataset.DATASET_ID_KEY))
}

withPlan {
AsOfJoin(
joined.left, joined.right,
@@ -1576,7 +1611,52 @@

case other => other
}
-Project(untypedCols.map(_.named), logicalPlan)
val namedExprs = untypedCols.map(_.named)
val inputSet = logicalPlan.outputSet

// True when `attr` carries a __dataset_id that differs from the dataset id
// recorded (or implied) for the matching attribute in this plan's output.
def hasMismatchedDatasetId(attr: Attribute): Boolean =
  attr.metadata.contains(DATASET_ID_KEY) &&
    (!inputSet.contains(attr) ||
      attr.metadata.getLong(DATASET_ID_KEY) !=
        inputSet.find(_.canonicalized == attr.canonicalized).map { x =>
          if (x.metadata.contains(DATASET_ID_KEY)) {
            x.metadata.getLong(DATASET_ID_KEY)
          } else {
            Dataset.this.id
          }
        }.get)

val rectifiedNamedExprs = namedExprs.map {
  case al: Alias if !al.references.subsetOf(inputSet) ||
      al.references.exists(hasMismatchedDatasetId) =>
    val unresolvedExpr = al.child.transformUp {
      case attr: AttributeReference if hasMismatchedDatasetId(attr) =>
        UnresolvedAttributeWithTag(attr, attr.metadata.getLong(DATASET_ID_KEY))
    }
    val newAl = al.copy(child = unresolvedExpr, name = al.name)(exprId = al.exprId,
      qualifier = al.qualifier, explicitMetadata = al.explicitMetadata,
      nonInheritableMetadataKeys = al.nonInheritableMetadataKeys)
    newAl.copyTagsFrom(al)
    newAl

  case attr: Attribute if hasMismatchedDatasetId(attr) =>
    UnresolvedAttributeWithTag(attr, attr.metadata.getLong(DATASET_ID_KEY))

  case other => other
}
Project(rectifiedNamedExprs, logicalPlan)
}
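
An illustrative end-to-end sketch of what the rectified projection enables
(invented frames; not part of the patch): selecting through the originating
DataFrames after a join, including under aliases, without tripping the
ambiguity check.

val df1 = spark.range(4).toDF("a")
val df2 = spark.range(4).toDF("a")
val j = df1.join(df2, df1("a") === df2("a"))
// Both projections reference columns via their source Dataset; mismatched
// references are re-tagged by the rewrite above and re-resolved by dataset id.
j.select(df1("a").as("left_a"), df2("a").as("right_a")).collect()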

@@ -19,6 +19,7 @@ package org.apache.spark.sql

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.plans.logical.AsOfJoin
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
@@ -173,4 +174,23 @@ class DataFrameAsOfJoinSuite extends QueryTest
)
)
}

test("SPARK_47217: Dedup of relations can impact projected columns resolution") {
val (df1, df2) = prepareForAsOfJoin()
val join1 = df1.join(df2, df1.col("a") === df2.col("a")).select(df2.col("a"), df1.col("b"),
df2.col("b"), df1.col("a").as("aa"))

// In stock Spark this would throw an ambiguous-column exception, even though
// the reference is not actually ambiguous.
val asOfJoin2 = join1.joinAsOf(
  df1, df1.col("a"), join1.col("a"), usingColumns = Seq.empty,
  joinType = "left", tolerance = null, allowExactMatches = false, direction = "nearest")

asOfJoin2.queryExecution.assertAnalyzed()

val testDf = asOfJoin2.select(df1.col("a"))
val analyzed = testDf.queryExecution.analyzed
val attributeRefToCheck = analyzed.output.head
assert(analyzed.children.head.asInstanceOf[AsOfJoin].right.outputSet
  .contains(attributeRefToCheck))
}
}