From d676b6277a682894d409e314e64ece7857a97841 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 25 Apr 2018 18:14:55 +0200
Subject: [PATCH] [SPARK-24051][SQL] Replace Aliases with the same exprId

---
 .../sql/catalyst/analysis/Analyzer.scala      | 78 +++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala | 12 +++
 2 files changed, 90 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e821e96522f7..670cf20dd1ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
@@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
+
 /**
  * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
  * Used for testing when all relations are already filled in and the analyzer needs only
@@ -145,6 +147,8 @@ class Analyzer(
       ResolveHints.RemoveAllHints),
     Batch("Simple Sanity Check", Once,
       LookupFunctions),
+    Batch("DeduplicateAliases", Once,
+      DeduplicateAliases),
     Batch("Substitution", fixedPoint,
       CTESubstitution,
       WindowsSubstitution,
@@ -284,6 +288,80 @@ class Analyzer(
     }
   }
 
+  /**
+   * Replaces [[Alias]] with the same exprId but different references with [[Alias]] having
+   * different exprIds. This is a rare situation which can cause incorrect results.
+   */
+  object DeduplicateAliases extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = {
+      val allAliases = collectAllAliasesInPlan(plan)
+      val dupAliases = allAliases.groupBy(_.exprId).collect {
+        case (eId, aliases) if containsDifferentAliases(aliases) => eId
+      }.toSeq
+      if (dupAliases.nonEmpty) {
+        val exprIdsDictionary = mutable.HashMap[ExprId, ExprId]()
+        resolveConflictingAliases(plan, dupAliases, exprIdsDictionary)
+      } else {
+        plan
+      }
+    }
+
+    def containsDifferentAliases(aliases: Seq[Alias]): Boolean = {
+      aliases.exists(a1 => aliases.exists(a2 => !a1.fastEquals(a2)))
+    }
+
+    def collectAllAliasesInPlan(plan: LogicalPlan): Seq[Alias] = {
+      plan.flatMap {
+        case Project(projectList, _) => projectList.collect { case a: Alias => a }
+        case AnalysisBarrier(child) => collectAllAliasesInPlan(child)
+        case _ => Nil
+      }
+    }
+
+    def containsExprIds(
+        projectList: Seq[NamedExpression],
+        exprIds: Seq[ExprId]): Boolean = {
+      projectList.count {
+        case a: Alias if exprIds.contains(a.exprId) => true
+        case a: AttributeReference if exprIds.contains(a.exprId) => true
+        case _ => false
+      } > 0
+    }
+
+    def renewConflictingAliases(
+        exprs: Seq[NamedExpression],
+        exprIds: Seq[ExprId],
+        exprIdsDictionary: mutable.HashMap[ExprId, ExprId]): Seq[NamedExpression] = {
+      exprs.map {
+        case a: Alias if exprIds.contains(a.exprId) =>
+          val newAlias = Alias(a.child, a.name)()
+          // update the map with the new id to replace
+          // since we are in a transformUp, all the parent nodes will see the updated map
+          exprIdsDictionary(a.exprId) = newAlias.exprId
+          newAlias
+        case a: AttributeReference if exprIds.contains(a.exprId) =>
+          // replace with the new id
+          a.withExprId(exprIdsDictionary(a.exprId))
+        case other => other
+      }
+    }
+
+    def resolveConflictingAliases(
+        plan: LogicalPlan,
+        dupAliases: Seq[ExprId],
+        exprIdsDictionary: mutable.HashMap[ExprId, ExprId]): LogicalPlan = {
+      plan.transformUp {
+        case p @ Project(projectList, _) if containsExprIds(projectList, dupAliases) =>
+          p.copy(renewConflictingAliases(projectList, dupAliases, exprIdsDictionary))
+        case a @ Aggregate(_, aggs, _) if containsExprIds(aggs, dupAliases) =>
+          a.copy(aggregateExpressions =
+            renewConflictingAliases(aggs, dupAliases, exprIdsDictionary))
+        case AnalysisBarrier(child) =>
+          AnalysisBarrier(resolveConflictingAliases(child, dupAliases, exprIdsDictionary))
+      }
+    }
+  }
+
   object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
     /*
      * GROUP BY a, b, c WITH ROLLUP
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 60e84e6ee750..d874afca4760 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Unio
 import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
@@ -2265,4 +2266,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.range(1).select($"id", new Column(Uuid()))
     checkAnswer(df, df.collect())
   }
+
+  test("SPARK-24051: using the same alias can produce incorrect result") {
+    val ds1 = Seq((1, 42), (2, 99)).toDF("a", "b")
+    val ds2 = Seq(3).toDF("a").withColumn("b", lit(0))
+
+    val cols = Seq(col("a"), col("b").alias("b"),
+      count(lit(1)).over(Window.partitionBy()).alias("n"))
+
+    val df = ds1.select(cols: _*).union(ds2.select(cols: _*))
+    checkAnswer(df, Row(1, 42, 2) :: Row(2, 99, 2) :: Row(3, 0, 1) :: Nil)
+  }
 }