diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1f217390518a6..6082c58e2c53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -357,6 +357,7 @@ object TypeCoercion { val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) + case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 744057b7c5f4c..2239bf815de71 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -57,7 +57,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "(numeric or calendarinterval) type") - assertError(Abs('stringField), "requires numeric type") assertError(BitwiseNot('stringField), "requires integral type") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index a8de23e73892c..a1e8a32ed8f66 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -85,3 +85,6 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu select BIT_LENGTH('abc'); select CHAR_LENGTH('abc'); select OCTET_LENGTH('abc'); + +-- abs +select abs(-3.13), abs('-2.19'); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 85ee10b4d274f..eac3080bec67d 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 54 +-- Number of queries: 55 -- !query 0 @@ -444,3 +444,11 @@ select OCTET_LENGTH('abc') struct -- !query 53 output 3 + + +-- !query 54 +select abs(-3.13), abs('-2.19') +-- !query 54 schema +struct +-- !query 54 output +3.13 2.19