From c4d16796f3ecec259e3e1af4afa9c13b4c5a142b Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sun, 14 May 2017 09:24:05 +0800
Subject: [PATCH] partial aggregate should behave correctly for sameResult

---
 .../expressions/aggregate/interfaces.scala    | 14 ++++++-
 .../spark/sql/catalyst/plans/QueryPlan.scala  |  4 +-
 .../spark/sql/execution/SameResultSuite.scala | 39 +++++++++++++++++++
 3 files changed, 53 insertions(+), 4 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index f3fd58bc98ef6..c43a449749877 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -105,12 +105,22 @@ case class AggregateExpression(
   }
 
   // We compute the same thing regardless of our final result.
-  override lazy val canonicalized: Expression =
+  override lazy val canonicalized: Expression = {
+    val normalizedAggFunc = mode match {
+      // In PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate buffers,
+      // and the actual children of `aggregateFunction` are not used, so normalize their expr ids.
+      case PartialMerge | Final => aggregateFunction.transform {
+        case a: AttributeReference => a.withExprId(ExprId(0))
+      }
+      case Partial | Complete => aggregateFunction
+    }
+
     AggregateExpression(
-      aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+      normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction],
       mode,
       isDistinct,
       ExprId(0))
+  }
 
   override def children: Seq[Expression] = aggregateFunction :: Nil
   override def dataType: DataType = aggregateFunction.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index a5761703fd655..2dbe22d0625b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -274,7 +274,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
 
     def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpression(e)
-      case Some(e: Expression) => Some(transformExpression(e))
+      case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
@@ -309,7 +309,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
 
     productIterator.flatMap {
       case e: Expression => e :: Nil
-      case Some(e: Expression) => e :: Nil
+      case s: Some[_] => seqToExpressions(s.toSeq)
       case seq: Traversable[_] => seqToExpressions(seq)
       case other => Nil
     }.toSeq
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala
new file mode 100644
index 0000000000000..3f7746c913518
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+/**
+ * Tests for the sameResult function for [[SparkPlan]]s.
+ */
+class SameResultSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
+    val df1 = spark.range(10).agg(sum($"id"))
+    val df2 = spark.range(10).agg(sum($"id"))
+    assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))
+
+    val df3 = spark.range(10).agg(sumDistinct($"id"))
+    val df4 = spark.range(10).agg(sumDistinct($"id"))
+    assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
+  }
+}
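
For context, and not part of the commit itself: a minimal sketch of the comparison the new test exercises, assuming a Spark 2.x spark-shell session where `spark` is the active SparkSession. The mention of partial/final aggregate physical nodes below is an illustration of typical planning for this query, not something asserted by the patch.

import org.apache.spark.sql.functions._
import spark.implicits._

// Two logically identical aggregations. Planning each one produces a partial (Partial mode) and a
// final (Final mode) aggregate node, and in Final mode the aggregate function still carries the
// original input attribute (`id`) as a child even though only the aggregation buffer is read.
// That attribute gets a fresh ExprId per query, which is why canonicalization must normalize it.
val q1 = spark.range(10).agg(sum($"id"))
val q2 = spark.range(10).agg(sum($"id"))

// With the ExprId normalization added in interfaces.scala, the canonicalized physical plans match
// and this prints true; without it, the differing ids made the comparison fail.
println(q1.queryExecution.executedPlan.sameResult(q2.queryExecution.executedPlan))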