From ac61a6620e59447c575092bee5d4d7f0af99695c Mon Sep 17 00:00:00 2001 From: donnyzone Date: Tue, 12 Sep 2017 17:28:01 +0800 Subject: [PATCH 1/4] SPARK-21980 --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../spark/sql/DataFrameAggregateSuite.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1e934d0aa0e5..d10a4088e28c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -314,7 +314,7 @@ class Analyzer( s"grouping columns (${groupByExprs.mkString(",")})") } case e @ Grouping(col: Expression) => - val idx = groupByExprs.indexOf(col) + val idx = groupByExprs.indexWhere(e => resolver(e.toString, col.toString)) if (idx >= 0) { Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), Literal(1)), ByteType), toPrettySQL(e))() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index affe97120c8f..33b859c96aab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -190,6 +190,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-21980: References in grouping functions should be indexed with resolver") { + checkAnswer( + courseSales.cube("course", "year") + .agg(grouping("CouRse"), grouping("year")), + Row("Java", 2012, 0, 0) :: + Row("Java", 2013, 0, 0) :: + Row("Java", null, 0, 1) :: + Row("dotNET", 2012, 0, 0) :: + Row("dotNET", 2013, 0, 0) :: + Row("dotNET", null, 0, 1) :: + Row(null, 2012, 1, 0) :: + Row(null, 2013, 1, 0) :: + Row(null, null, 1, 1) :: Nil + ) + } + test("rollup overlapping columns") { checkAnswer( testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), From b08fd9301cdbd4c1a29d5eb322eacd1cf2ffc546 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Tue, 12 Sep 2017 17:34:53 +0800 Subject: [PATCH 2/4] rename --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d10a4088e28c..2983800143e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -314,7 +314,7 @@ class Analyzer( s"grouping columns (${groupByExprs.mkString(",")})") } case e @ Grouping(col: Expression) => - val idx = groupByExprs.indexWhere(e => resolver(e.toString, col.toString)) + val idx = groupByExprs.indexWhere(x => resolver(x.toString, col.toString)) if (idx >= 0) { Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), Literal(1)), ByteType), toPrettySQL(e))() From e24fdb8fe525263529f457d5e723bb396057ea0a Mon Sep 17 00:00:00 2001 From: donnyzone Date: Wed, 13 Sep 2017 09:49:37 +0800 Subject: [PATCH 3/4] use semanticEquals --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2983800143e6..0880bd66ea4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -314,7 +314,7 @@ class Analyzer( s"grouping columns (${groupByExprs.mkString(",")})") } case e @ Grouping(col: Expression) => - val idx = groupByExprs.indexWhere(x => resolver(x.toString, col.toString)) + val idx = groupByExprs.indexWhere(_.semanticEquals(col)) if (idx >= 0) { Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), Literal(1)), ByteType), toPrettySQL(e))() From 09efc4d9e4127244449da05fa44dc3de0ceb5b05 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Wed, 13 Sep 2017 10:12:42 +0800 Subject: [PATCH 4/4] test name --- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 33b859c96aab..8549eac58ee9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -190,7 +190,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("SPARK-21980: References in grouping functions should be indexed with resolver") { + test("SPARK-21980: References in grouping functions should be indexed with semanticEquals") { checkAnswer( courseSales.cube("course", "year") .agg(grouping("CouRse"), grouping("year")),