From b906c8cc541225ac0f7c5ebda725f4bed758eb58 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 3 Feb 2019 20:37:20 +0100 Subject: [PATCH 1/5] [SPARK-26572][SQL] fix aggregate codegen result evaluation --- .../sql/execution/aggregate/HashAggregateExec.scala | 12 ++++++++++-- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 10 ++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 19a47ffc6dd0..b1d9bfdf6c47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -466,10 +466,12 @@ case class HashAggregateExec( val resultVars = bindReferences[Expression]( resultExpressions, inputAttrs).map(_.genCode(ctx)) + val evaluateResultVars = evaluateVariables(resultVars) s""" $evaluateKeyVars $evaluateBufferVars $evaluateAggResults + $evaluateResultVars ${consume(ctx, resultVars)} """ } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { @@ -497,19 +499,25 @@ case class HashAggregateExec( val resultVars = bindReferences[Expression]( resultExpressions, inputAttrs).map(_.genCode(ctx)) + val evaluateResultVars = evaluateVariables(resultVars) s""" $evaluateKeyVars $evaluateResultBufferVars + $evaluateResultVars ${consume(ctx, resultVars)} """ } else { // generate result based on grouping key ctx.INPUT_ROW = keyTerm ctx.currentVars = null - val eval = bindReferences[Expression]( + val resultVars = bindReferences[Expression]( resultExpressions, groupingAttributes).map(_.genCode(ctx)) - consume(ctx, eval) + val evaluateResultVars = evaluateVariables(resultVars) + s""" + $evaluateResultVars + ${consume(ctx, resultVars)} + """ } ctx.addNewFunction(funcName, s""" diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3082e0bb97df..400f1dfe8575 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2110,4 +2110,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(res, Row("1-1", 6, 6)) } } + + test("SPARK-26572: fix aggregate codegen result evaluation") { + val baseTable = Seq((1), (1)).toDF("idx") + val distinctWithId = + baseTable.distinct.withColumn("id", functions.monotonically_increasing_id()) + val res = baseTable.join(distinctWithId, "idx") + .groupBy("id").count().as("count") + .select("count") + checkAnswer(res, Row(2)) + } } From b5d079c59995bee1c3d50d630223dd1ae0edbe68 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 3 Feb 2019 21:05:13 +0100 Subject: [PATCH 2/5] indentation fix Change-Id: Ie07c913fc4586296c8187f0972c19169da25f613 --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 400f1dfe8575..48075599cc4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2114,7 +2114,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-26572: fix aggregate codegen result evaluation") { val baseTable = Seq((1), (1)).toDF("idx") val distinctWithId = - baseTable.distinct.withColumn("id", functions.monotonically_increasing_id()) + baseTable.distinct.withColumn("id", functions.monotonically_increasing_id()) val res = baseTable.join(distinctWithId, "idx") .groupBy("id").count().as("count") .select("count") From 567f8f69f2aa9b713612e159c853446da5e84326 Mon Sep 17 00:00:00 
2001 From: Peter Toth Date: Tue, 5 Feb 2019 15:21:26 +0100 Subject: [PATCH 3/5] incorporate suggested changes --- .../sql/execution/WholeStageCodegenExec.scala | 12 ++++++++ .../aggregate/HashAggregateExec.scala | 12 ++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 28 +++++++++++++++---- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index d3a93f5eb395..443ce8cc46a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -290,6 +290,18 @@ trait CodegenSupport extends SparkPlan { evaluateVars.toString() } + /** + * Returns source code to evaluate the variables for non-deterministic expressions, and clear the + * code of evaluated variables, to prevent them from being evaluated twice. + */ + protected def evaluateNondeterministicVariables( + attributes: Seq[Attribute], + variables: Seq[ExprCode], + expressions: Seq[NamedExpression]): String = { + val nondeterministicAttrs = expressions.filterNot(_.deterministic).map(_.toAttribute) + evaluateRequiredVariables(attributes, variables, AttributeSet(nondeterministicAttrs)) + } + /** * The subset of inputSet those should be evaluated before this plan. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b1d9bfdf6c47..3307aa4eca22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -466,12 +466,13 @@ case class HashAggregateExec( val resultVars = bindReferences[Expression]( resultExpressions, inputAttrs).map(_.genCode(ctx)) - val evaluateResultVars = evaluateVariables(resultVars) + val evaluateNondeterministicAggResults = + evaluateNondeterministicVariables(output, resultVars, resultExpressions) s""" $evaluateKeyVars $evaluateBufferVars $evaluateAggResults - $evaluateResultVars + $evaluateNondeterministicAggResults ${consume(ctx, resultVars)} """ } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { @@ -499,11 +500,9 @@ case class HashAggregateExec( val resultVars = bindReferences[Expression]( resultExpressions, inputAttrs).map(_.genCode(ctx)) - val evaluateResultVars = evaluateVariables(resultVars) s""" $evaluateKeyVars $evaluateResultBufferVars - $evaluateResultVars ${consume(ctx, resultVars)} """ } else { @@ -513,9 +512,10 @@ case class HashAggregateExec( val resultVars = bindReferences[Expression]( resultExpressions, groupingAttributes).map(_.genCode(ctx)) - val evaluateResultVars = evaluateVariables(resultVars) + val evaluateNondeterministicAggResults = + evaluateNondeterministicVariables(output, resultVars, resultExpressions) s""" - $evaluateResultVars + $evaluateNondeterministicAggResults ${consume(ctx, resultVars)} """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 48075599cc4a..3fa578e43c2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -30,11 +30,13 @@ import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Uuid +import org.apache.spark.sql.catalyst.expressions.aggregate.Final import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} @@ -2113,11 +2115,25 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-26572: fix aggregate codegen result evaluation") { val baseTable = Seq((1), (1)).toDF("idx") - val distinctWithId = - baseTable.distinct.withColumn("id", functions.monotonically_increasing_id()) - val res = baseTable.join(distinctWithId, "idx") - .groupBy("id").count().as("count") - .select("count") - checkAnswer(res, Row(2)) + + // BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions + val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id()) + .join(baseTable, "idx") + assert(distinctWithId.queryExecution.executedPlan.collectFirst { + case BroadcastHashJoinExec(_, _, _, _, _, HashAggregateExec(_, _, Seq(), _, _, _, _), _) => + true + }.isDefined) + checkAnswer(distinctWithId, Seq(Row(1, 25769803776L), Row(1, 25769803776L))) + + // BroadcastHashJoinExec with a HashAggregateExec child containing a 
Final mode aggregate + // expression + val groupByWithId = + baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id()) + .join(baseTable, "idx") + assert(groupByWithId.queryExecution.executedPlan.collectFirst { + case BroadcastHashJoinExec(_, _, _, _, _, HashAggregateExec(_, _, ae, _, _, _, _), _) + if ae.exists(_.mode == Final) => true + }.isDefined) + checkAnswer(groupByWithId, Seq(Row(1, 2, 25769803776L), Row(1, 2, 25769803776L))) } } From 5ae9add508a9341c1ca781ffd54598d560d16ed1 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 13 Feb 2019 11:24:25 +0100 Subject: [PATCH 4/5] fix review findings --- .../org/apache/spark/sql/DataFrameSuite.scala | 26 ---------------- .../execution/WholeStageCodegenSuite.scala | 30 ++++++++++++++++++- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3fa578e43c2b..3082e0bb97df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -30,13 +30,11 @@ import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Uuid -import org.apache.spark.sql.catalyst.expressions.aggregate.Final import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions._ import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} @@ -2112,28 +2110,4 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(res, Row("1-1", 6, 6)) } } - - test("SPARK-26572: fix aggregate codegen result evaluation") { - val baseTable = Seq((1), (1)).toDF("idx") - - // BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions - val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id()) - .join(baseTable, "idx") - assert(distinctWithId.queryExecution.executedPlan.collectFirst { - case BroadcastHashJoinExec(_, _, _, _, _, HashAggregateExec(_, _, Seq(), _, _, _, _), _) => - true - }.isDefined) - checkAnswer(distinctWithId, Seq(Row(1, 25769803776L), Row(1, 25769803776L))) - - // BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate - // expression - val groupByWithId = - baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id()) - .join(baseTable, "idx") - assert(groupByWithId.queryExecution.executedPlan.collectFirst { - case BroadcastHashJoinExec(_, _, _, _, _, HashAggregateExec(_, _, ae, _, _, _, _), _) - if ae.exists(_.mode == Final) => true - }.isDefined) - checkAnswer(groupByWithId, Seq(Row(1, 2, 25769803776L), Row(1, 2, 25769803776L))) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index e03f08417162..3c9a0908147a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import 
org.apache.spark.sql.expressions.scalalang.typed -import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -339,4 +339,32 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row(1, 3), Row(2, 3))) } + + test("SPARK-26572: evaluate non-deterministic expressions for aggregate results") { + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val baseTable = Seq(1, 1).toDF("idx") + + // BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions + val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id()) + .join(baseTable, "idx") + assert(distinctWithId.queryExecution.executedPlan.collectFirst { + case WholeStageCodegenExec( + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true + }.isDefined) + checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0))) + + // BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate + // expression + val groupByWithId = + baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id()) + .join(baseTable, "idx") + assert(groupByWithId.queryExecution.executedPlan.collectFirst { + case WholeStageCodegenExec( + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true + }.isDefined) + checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0))) + } + } } From af861d554bbcc9785287e1ee74ec9d602fea4705 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 14 Feb 2019 11:02:15 +0100 Subject: [PATCH 5/5] fix var name Change-Id: I1a2c52e7ba30a186517d91568093da813f201d1f --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 8 ++++---- 1 file 
changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 3307aa4eca22..17cc7fde42bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -466,13 +466,13 @@ case class HashAggregateExec( val resultVars = bindReferences[Expression]( resultExpressions, inputAttrs).map(_.genCode(ctx)) - val evaluateNondeterministicAggResults = + val evaluateNondeterministicResults = evaluateNondeterministicVariables(output, resultVars, resultExpressions) s""" $evaluateKeyVars $evaluateBufferVars $evaluateAggResults - $evaluateNondeterministicAggResults + $evaluateNondeterministicResults ${consume(ctx, resultVars)} """ } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { @@ -512,10 +512,10 @@ case class HashAggregateExec( val resultVars = bindReferences[Expression]( resultExpressions, groupingAttributes).map(_.genCode(ctx)) - val evaluateNondeterministicAggResults = + val evaluateNondeterministicResults = evaluateNondeterministicVariables(output, resultVars, resultExpressions) s""" - $evaluateNondeterministicAggResults + $evaluateNondeterministicResults ${consume(ctx, resultVars)} """ }