From 1153f752b0c3e9cb27250eb842c83a3f84e00362 Mon Sep 17 00:00:00 2001 From: egraldlo Date: Thu, 5 Jun 2014 16:31:43 +0800 Subject: [PATCH 1/4] do best to avoid overflowing in function avg(). --- .../spark/sql/catalyst/expressions/aggregates.scala | 8 ++++---- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 6 ++++++ .../test/scala/org/apache/spark/sql/TestData.scala | 11 +++++++++++ 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 01947273b6ccc..47b4e643395eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -214,10 +214,10 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN override def toString = s"AVG($child)" override def asPartial: SplitEvaluation = { - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) + val partialSum = Alias(Sum(Cast(child, dataType)), "PartialSum")() + val partialCount = Alias(Cast(Count(child), dataType), "PartialCount")() + val castedSum = Sum(partialSum.toAttribute) + val castedCount = Sum(partialCount.toAttribute) SplitEvaluation( Divide(castedSum, castedCount), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 95860e6683f67..c47c243bc0f71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -123,6 +123,12 @@ class SQLQuerySuite extends QueryTest { 2.0) } + test("average overflow 
test") { + checkAnswer( + sql("SELECT AVG(a),b FROM testData1 group by b"), + Seq((2147483645.0,1),(2.0,2))) + } + test("count") { checkAnswer( sql("SELECT COUNT(*) FROM testData2"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 944f520e43515..ec494d49a2fb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -30,6 +30,17 @@ object TestData { (1 to 100).map(i => TestData(i, i.toString))) testData.registerAsTable("testData") + case class TestData1(a: Int, b: Int) + val testData1: SchemaRDD = + TestSQLContext.sparkContext.parallelize( + TestData1(2147483644, 1) :: + TestData1(1, 2) :: + TestData1(2147483645, 1) :: + TestData1(2, 2) :: + TestData1(2147483646, 1) :: + TestData1(3, 2) :: Nil) + testData1.registerAsTable("testData1") + case class TestData2(a: Int, b: Int) val testData2: SchemaRDD = TestSQLContext.sparkContext.parallelize( From d414cd70a719e6ace60b588da65457300970a8f1 Mon Sep 17 00:00:00 2001 From: egraldlo Date: Thu, 5 Jun 2014 16:42:36 +0800 Subject: [PATCH 2/4] formatting issues --- sql/core/src/test/scala/org/apache/spark/sql/TestData.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index ec494d49a2fb8..b8fe3c3856325 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -32,8 +32,8 @@ object TestData { case class TestData1(a: Int, b: Int) val testData1: SchemaRDD = - TestSQLContext.sparkContext.parallelize( - TestData1(2147483644, 1) :: + TestSQLContext.sparkContext.parallelize( + TestData1(2147483644, 1) :: TestData1(1, 2) :: TestData1(2147483645, 1) :: TestData1(2, 2) :: From 762aeaf149d7eca2e42091072f486c450b80700b Mon Sep 17 00:00:00 
2001 From: Michael Armbrust Date: Tue, 10 Jun 2014 01:00:56 -0700 Subject: [PATCH 3/4] Remove unneeded rule. More descriptive name for test table. --- .../sql/catalyst/expressions/aggregates.scala | 8 ++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../scala/org/apache/spark/sql/TestData.scala | 18 +++++++++--------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 47b4e643395eb..01947273b6ccc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -214,10 +214,10 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN override def toString = s"AVG($child)" override def asPartial: SplitEvaluation = { - val partialSum = Alias(Sum(Cast(child, dataType)), "PartialSum")() - val partialCount = Alias(Cast(Count(child), dataType), "PartialCount")() - val castedSum = Sum(partialSum.toAttribute) - val castedCount = Sum(partialCount.toAttribute) + val partialSum = Alias(Sum(child), "PartialSum")() + val partialCount = Alias(Count(child), "PartialCount")() + val castedSum = Cast(Sum(partialSum.toAttribute), dataType) + val castedCount = Cast(Sum(partialCount.toAttribute), dataType) SplitEvaluation( Divide(castedSum, castedCount), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c47c243bc0f71..89277cec3248b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -125,7 +125,7 @@ class SQLQuerySuite extends QueryTest { test("average overflow test") { checkAnswer( - sql("SELECT AVG(a),b FROM testData1 group by 
b"), + sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), Seq((2147483645.0,1),(2.0,2))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index b8fe3c3856325..8b94e3c8e4d37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -30,16 +30,16 @@ object TestData { (1 to 100).map(i => TestData(i, i.toString))) testData.registerAsTable("testData") - case class TestData1(a: Int, b: Int) - val testData1: SchemaRDD = + case class LargeAndSmallInts(a: Int, b: Int) + val largeAndSmallInts: SchemaRDD = TestSQLContext.sparkContext.parallelize( - TestData1(2147483644, 1) :: - TestData1(1, 2) :: - TestData1(2147483645, 1) :: - TestData1(2, 2) :: - TestData1(2147483646, 1) :: - TestData1(3, 2) :: Nil) - testData1.registerAsTable("testData1") + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil) + largeAndSmallInts.registerAsTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) val testData2: SchemaRDD = From e228c5e1d9973bab6a951c9b040014e2ddb45de9 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 10 Jun 2014 01:08:12 -0700 Subject: [PATCH 4/4] Remove "test". 
--- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 89277cec3248b..1b9335b70f4ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -123,7 +123,7 @@ class SQLQuerySuite extends QueryTest { 2.0) } - test("average overflow test") { + test("average overflow") { checkAnswer( sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), Seq((2147483645.0,1),(2.0,2)))