diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AliasResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AliasResolution.scala new file mode 100644 index 000000000000..fa3300d57de8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AliasResolution.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.analysis.MultiAlias
+import org.apache.spark.sql.catalyst.expressions.{
+  Alias,
+  Attribute,
+  Cast,
+  Expression,
+  ExtractValue,
+  Generator,
+  GeneratorOuter,
+  Literal,
+  NamedExpression
+}
+import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ALIAS
+import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS}
+import org.apache.spark.sql.types.MetadataBuilder
+
+/**
+ * Utilities for turning [[UnresolvedAlias]] nodes into concrete aliases.
+ */
+object AliasResolution {
+
+  /**
+   * Returns true when `exists` finds an [[UnresolvedAlias]] node in at least one of the
+   * given expressions.
+   */
+  def hasUnresolvedAlias(exprs: Seq[NamedExpression]): Boolean = {
+    exprs.exists { expr =>
+      expr.exists {
+        case _: UnresolvedAlias => true
+        case _ => false
+      }
+    }
+  }
+
+  /**
+   * Runs `transformUpWithPruning` over every expression, using
+   * `containsPattern(UNRESOLVED_ALIAS)` as the pruning condition, and passes each
+   * [[UnresolvedAlias]] node it visits through [[resolve]].
+   */
+  def assignAliases(exprs: Seq[NamedExpression]): Seq[NamedExpression] = {
+    val replaceUnresolvedAlias: PartialFunction[Expression, Expression] = {
+      case unresolved: UnresolvedAlias => resolve(unresolved)
+    }
+    val rewritten = exprs.map { expr =>
+      expr.transformUpWithPruning(_.containsPattern(UNRESOLVED_ALIAS))(replaceUnresolvedAlias)
+    }
+    // The transform is typed as Expression, so the result is cast back for callers.
+    rewritten.asInstanceOf[Seq[NamedExpression]]
+  }
+
+  /**
+   * Picks the replacement for a single [[UnresolvedAlias]]. The cases are tried in order:
+   *  - a child that already is a [[NamedExpression]] is returned unchanged;
+   *  - a [[GeneratorOuter]] wrapping a resolved [[Generator]] becomes a [[MultiAlias]]
+   *    with an empty name list;
+   *  - a child that is not resolved yet leaves `u` as it is;
+   *  - any other [[Generator]] becomes a [[MultiAlias]] with an empty name list;
+   *  - a [[Cast]] whose first extracted field is a [[NamedExpression]] reuses that
+   *    expression's name;
+   *  - an [[ExtractValue]] accepted by `extractOnly` is named by `toPrettySQL`;
+   *  - otherwise the optional function carried by `u` names the alias when it is present,
+   *    a [[Literal]] is named by `toPrettySQL`, and everything else is named by
+   *    `toPrettySQL` with metadata that sets `AUTO_GENERATED_ALIAS` to "true".
+   */
+  def resolve(u: UnresolvedAlias): Expression = {
+    val UnresolvedAlias(target, nameGenerator) = u
+    target match {
+      case named: NamedExpression => named
+      case outer @ GeneratorOuter(gen: Generator) if gen.resolved => MultiAlias(outer, Nil)
+      case pending if !pending.resolved => u
+      case gen: Generator => MultiAlias(gen, Nil)
+      case cast @ Cast(source: NamedExpression, _, _, _) => Alias(cast, source.name)()
+      case extract: ExtractValue if extractOnly(extract) =>
+        Alias(extract, toPrettySQL(extract))()
+      case other =>
+        nameGenerator match {
+          case Some(generate) => Alias(other, generate(other))()
+          case None =>
+            other match {
+              case literal: Literal => Alias(literal, toPrettySQL(literal))()
+              case _ =>
+                val autoGeneratedMeta = new MetadataBuilder()
+                  .putString(AUTO_GENERATED_ALIAS, "true")
+                  .build()
+                Alias(other, toPrettySQL(other))(explicitMetadata = Some(autoGeneratedMeta))
+            }
+        }
+    }
+  }
+
+  /**
+   * True when `e` is a [[Literal]], an [[Attribute]], or an [[ExtractValue]] whose children
+   * all satisfy this same check.
+   */
+  private def extractOnly(e: Expression): Boolean = {
+    e match {
+      case _: ExtractValue => e.children.forall(child => extractOnly(child))
+      case _: Literal | _: Attribute => true
+      case _ => false
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d1d04d411726..d2103d32a6a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -444,62 +444,42 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor * Replaces [[UnresolvedAlias]]s with concrete aliases. 
*/ object ResolveAliases extends Rule[LogicalPlan] { - private def assignAliases(exprs: Seq[NamedExpression]) = { - exprs.map(_.transformUpWithPruning(_.containsPattern(UNRESOLVED_ALIAS)) { - case u: UnresolvedAlias => resolve(u) - } - ).asInstanceOf[Seq[NamedExpression]] - } - - private[analysis] def resolve(u: UnresolvedAlias): Expression = { - val UnresolvedAlias(child, optGenAliasFunc) = u - child match { - case ne: NamedExpression => ne - case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) - case e if !e.resolved => u - case g: Generator => MultiAlias(g, Nil) - case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)() - case e: ExtractValue if extractOnly(e) => Alias(e, toPrettySQL(e))() - case e if optGenAliasFunc.isDefined => - Alias(child, optGenAliasFunc.get.apply(e))() - case l: Literal => Alias(l, toPrettySQL(l))() - case e => - val metaForAutoGeneratedAlias = new MetadataBuilder() - .putString(AUTO_GENERATED_ALIAS, "true") - .build() - Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias)) - } - } - - private def extractOnly(e: Expression): Boolean = e match { - case _: ExtractValue => e.children.forall(extractOnly) - case _: Literal => true - case _: Attribute => true - case _ => false - } - - private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = - exprs.exists(_.exists(_.isInstanceOf[UnresolvedAlias])) - - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsPattern(UNRESOLVED_ALIAS), ruleId) { - case Aggregate(groups, aggs, child, _) if child.resolved && hasUnresolvedAlias(aggs) => - Aggregate(groups, assignAliases(aggs), child) - - case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) - if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => - Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) + /** Rewrites `plan` through `resolveOperatorsUpWithPruning`, restricted to plans for which `containsPattern(UNRESOLVED_ALIAS)` holds. Aggregate, Pivot, Unpivot, Project and CollectMetrics operators are rebuilt only when their child is resolved and the relevant expression list still holds an [[UnresolvedAlias]]; the replacement expressions come from `AliasResolution.assignAliases`. */ def apply(plan: LogicalPlan): LogicalPlan = + 
plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ALIAS), ruleId) { + /* NOTE(review): the pattern extracts four fields but the Aggregate is rebuilt from three arguments, so the fourth field is not carried over - confirm that falling back to its default is intended. */ case Aggregate(groups, aggs, child, _) + if child.resolved && AliasResolution.hasUnresolvedAlias(aggs) => + Aggregate(groups, AliasResolution.assignAliases(aggs), child) + + /* Only the optional group-by list of the Pivot is re-aliased; the other fields are passed through unchanged. */ case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) + if child.resolved && + groupByOpt.isDefined && + AliasResolution.hasUnresolvedAlias(groupByOpt.get) => + Pivot( + Some(AliasResolution.assignAliases(groupByOpt.get)), + pivotColumn, + pivotValues, + aggregates, + child + ) - case up: Unpivot if up.child.resolved && - (up.ids.exists(hasUnresolvedAlias) || up.values.exists(_.exists(hasUnresolvedAlias))) => - up.copy(ids = up.ids.map(assignAliases), values = up.values.map(_.map(assignAliases))) + /* Unpivot is rewritten when either its ids or any of its value lists holds an unresolved alias; both ids and values are then re-aliased. */ case up: Unpivot + if up.child.resolved && + (up.ids.exists(AliasResolution.hasUnresolvedAlias) || up.values.exists( + _.exists(AliasResolution.hasUnresolvedAlias) + )) => + up.copy( + ids = up.ids.map(AliasResolution.assignAliases), + values = up.values.map(_.map(AliasResolution.assignAliases)) + ) - case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => - Project(assignAliases(projectList), child) + case Project(projectList, child) + if child.resolved && AliasResolution.hasUnresolvedAlias(projectList) => + Project(AliasResolution.assignAliases(projectList), child) - case c: CollectMetrics if c.child.resolved && hasUnresolvedAlias(c.metrics) => - c.copy(metrics = assignAliases(c.metrics)) - } + case c: CollectMetrics + if c.child.resolved && AliasResolution.hasUnresolvedAlias(c.metrics) => + c.copy(metrics = AliasResolution.assignAliases(c.metrics)) + } } object ResolveGroupingAnalytics extends Rule[LogicalPlan] {