diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 4aac2c6c706..6b6d7cb4d17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -156,9 +156,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
           viewName, colName, expectedNumCandidates, matched, viewDDL)
       }
       matched(ordinal)
-
+ /*
     case u @ UnresolvedAttributeWithTag(attr, id) =>
-      resolveOnDatasetId(id, attr.name).getOrElse(attr)
+      resolveOnDatasetId(id, attr.name).getOrElse(attr) */
 
     case u @ UnresolvedAttribute(nameParts) =>
       val result = withPosition(u) {
@@ -577,13 +577,34 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
 
   private def resolveDataFrameColumn(
       u: UnresolvedAttribute,
      q: Seq[LogicalPlan]): Option[NamedExpression] = {
-    val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
-    if (planIdOpt.isEmpty) return None
-    val planId = planIdOpt.get
-    logDebug(s"Extract plan_id $planId from $u")
+
+    val resolutionSpecOpt = u.getTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG).map(
+      id => {
+        val initPlans = if (q.size == 1) {
+          val binaryNodeOpt = q.head.collectFirst {
+            case bn: BinaryNode => bn
+          }
+          binaryNodeOpt.map(_.children).getOrElse(q)
+        } else {
+          q
+        }
+
+        (id, (lp: LogicalPlan, id: Long) => lp.getTagValue(LogicalPlan.DATASET_ID_TAG).
+          exists(_.contains(id)), initPlans, (lp: LogicalPlan) => lp.children.size > 1)
+      }
+    ).orElse(
+      u.getTagValue(LogicalPlan.PLAN_ID_TAG).map(
+        (_, (lp: LogicalPlan, id: Long) => lp.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id),
+          q, (lp: LogicalPlan) => false)
+      ))
+
+    if (resolutionSpecOpt.isEmpty) return None
+    val (id, idChecker, startPlans, endRecursion) = resolutionSpecOpt.get
+    logDebug(s"Extract plan_id $id from $u")
     val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty
-    val (resolved, matched) = resolveDataFrameColumnByPlanId(u, planId, isMetadataAccess, q)
+    val (resolved, matched) = resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, startPlans,
+      idChecker, endRecursion)
     if (!matched) {
       // Can not find the target plan node with plan id, e.g.
       // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
@@ -598,8 +619,11 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
       u: UnresolvedAttribute,
       id: Long,
       isMetadataAccess: Boolean,
-      q: Seq[LogicalPlan]): (Option[NamedExpression], Boolean) = {
-    q.iterator.map(resolveDataFrameColumnRecursively(u, id, isMetadataAccess, _))
+      q: Seq[LogicalPlan],
+      idChecker: (LogicalPlan, Long) => Boolean,
+      endRecursion: LogicalPlan => Boolean): (Option[NamedExpression], Boolean) = {
+    q.iterator.map(resolveDataFrameColumnRecursively(u, id, isMetadataAccess, _, idChecker,
+      endRecursion))
       .foldLeft((Option.empty[NamedExpression], false)) {
         case ((r1, m1), (r2, m2)) =>
           if (r1.nonEmpty && r2.nonEmpty) {
@@ -613,8 +637,10 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
       u: UnresolvedAttribute,
       id: Long,
       isMetadataAccess: Boolean,
-      p: LogicalPlan): (Option[NamedExpression], Boolean) = {
-    val (resolved, matched) = if (p.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
+      p: LogicalPlan,
+      idChecker: (LogicalPlan, Long) => Boolean,
+      endRecursion: LogicalPlan => Boolean): (Option[NamedExpression], Boolean) = {
+    val (resolved, matched) = if (idChecker(p, id)) {
       val resolved = try {
         if (!isMetadataAccess) {
           p.resolve(u.nameParts, conf.resolver)
@@ -629,8 +655,10 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
           None
       }
       (resolved, true)
+    } else if (endRecursion(p)) {
+      (None, false)
     } else {
-      resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, p.children)
+      resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, p.children, idChecker, endRecursion)
     }
 
     // In self join case like:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 397351e0c1f..a46105dff85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -267,7 +267,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
     nameParts.length == 1 && nameParts.head.equalsIgnoreCase(token)
   }
 }
-
+/*
 
 case class UnresolvedAttributeWithTag(attribute: Attribute, datasetId: Long) extends Attribute with Unevaluable {
   def name: String = attribute.name
@@ -309,6 +309,8 @@ case class UnresolvedAttributeWithTag(attribute: Attribute, datasetId: Long) ext
   def equalsIgnoreCase(token: String): Boolean = token.equalsIgnoreCase(attribute.name)
 }
+ */
+
 
 object UnresolvedAttribute extends AttributeNameParser {
   /**
    * Creates an [[UnresolvedAttribute]], parsing segments separated by dots ('.').
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index a9b130c981a..45cce596682 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -201,6 +201,7 @@ object LogicalPlan {
   private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
   private[spark] val IS_METADATA_COL = TreeNodeTag[Unit]("is_metadata_col")
   private[spark] val DATASET_ID_TAG = TreeNodeTag[mutable.HashSet[Long]]("dataset_id")
+  private[spark] val ATTRIBUTE_DATASET_ID_TAG = TreeNodeTag[Long]("attribute_dataset_id")
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1186cd00ec2..1f1edfb0e0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1193,7 +1193,11 @@ class Dataset[T] private[sql](
           val rightLegWrong = isIncorrectlyResolved(attr, planPart1.right.outputSet,
             rightTagIdMap.getOrElse(HashSet.empty[Long]))
           if (!planPart1.outputSet.contains(attr) || leftLegWrong || rightLegWrong) {
-            UnresolvedAttributeWithTag(attr, attr.metadata.getLong(DATASET_ID_KEY))
+            val ua = UnresolvedAttribute(attr.name)
+            ua.copyTagsFrom(attr)
+            ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG,
+              attr.metadata.getLong(DATASET_ID_KEY))
+            ua
           } else {
             attr
           }
@@ -1337,7 +1341,10 @@ class Dataset[T] private[sql](
         joined.left.output(index)
 
       case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
-        UnresolvedAttributeWithTag(a, a.metadata.getLong(Dataset.DATASET_ID_KEY))
+        val ua = UnresolvedAttribute(a.name)
+        ua.copyTagsFrom(a)
+        ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG, a.metadata.getLong(DATASET_ID_KEY))
+        ua
     }
     val rightAsOfExpr = rightAsOf.expr.transformUp {
       case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
@@ -1345,7 +1352,10 @@ class Dataset[T] private[sql](
         joined.right.output(index)
 
      case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
-        UnresolvedAttributeWithTag(a, a.metadata.getLong(Dataset.DATASET_ID_KEY))
+        val ua = UnresolvedAttribute(a.name)
+        ua.copyTagsFrom(a)
+        ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG, a.metadata.getLong(DATASET_ID_KEY))
+        ua
     }
     withPlan {
       AsOfJoin(
@@ -1614,7 +1624,11 @@ class Dataset[T] private[sql](
        case attr: AttributeReference if attr.metadata.contains(DATASET_ID_KEY) &&
          (!inputForProj.contains(attr) ||
            isIncorrectlyResolved(attr, inputForProj, HashSet(id))) =>
-          UnresolvedAttributeWithTag(attr, attr.metadata.getLong(DATASET_ID_KEY))
+          val ua = UnresolvedAttribute(attr.name)
+          ua.copyTagsFrom(attr)
+          ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG, attr.metadata.getLong(DATASET_ID_KEY))
+          ua
+
       }).asInstanceOf[NamedExpression])
     Project(namedExprs, logicalPlan)
   }