diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 5ecb77be5965..9c31a487d37c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -36,7 +36,13 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { } private lazy val sumDataType = child.dataType match { - case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) + /* + * In case of sum of decimal ( assuming another decimal of same precision and scale) + * Refer : org.apache.spark.sql.catalyst.analysis.DecimalPrecision + * Precision : max(s1, s2) + max(p1 - s1, p2 - s2) + 1 + * Scale : max(s1, s2) + */ + case _ @ DecimalType.Fixed(p, s) => DecimalType.adjustPrecisionScale(s + (p - s) + 1, s) case _ => DoubleType } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b9c32e789a41..44c863cc9e1f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1192,6 +1192,25 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd } } } + + test("SPARK-25413 Test scale and precision") { + val expected = new java.math.BigDecimal("37800224355780013.7598204253756364") + sql("create table if not exists table1(salary decimal(31,12))") + sql("insert into table1 values(12345678901234510.1234567890123)") + sql("insert into table1 values(12345678901234520.1234567890123)") + sql("insert into table1 values(12345678901234530.1234567890123)") + sql("insert into table1 values(12345678901234560.1234567890123)") + sql("insert into table1 values(22345678901234560.1234567890123)") + sql("insert into table1 values(32345678901234560.1234567890123)") + sql("insert into table1 values(42345678901234560.1234567890123)") + sql("insert into table1 values(52345678901234560.1234567890123)") + sql("insert into table1 values(62345678901234560.1234567890123)") + sql("insert into table1 values(72345678901234560.1234567890123)") + sql("insert into table1 values(82345678901234560.1234567890123)") + assert(sql("select avg(salary)+10 from table1") + .first() + .getAs[java.math.BigDecimal](0).equals(expected)) + } } // for SPARK-2180 test