diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 8b69a4703696..344f89ba9f31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -262,6 +262,7 @@ object FunctionRegistry { expression[Tan]("tan"), expression[Cot]("cot"), expression[Tanh]("tanh"), + expression[Truncate]("truncate"), expression[Add]("+"), expression[Subtract]("-"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c2e1720259b5..942a4d0f99a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1245,3 +1245,27 @@ case class BRound(child: Expression, scale: Expression) with Serializable with ImplicitCastInputTypes { def this(child: Expression) = this(child, Literal(0)) } + +/** + * The number truncated to scale decimal places. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(number, scale) - Returns number truncated to scale decimal places. " + + "If scale is omitted, then number is truncated to 0 places. " + + "scale can be negative to truncate (make zero) scale digits left of the decimal point.", + examples = """ + Examples: + > SELECT _FUNC_(1234567891.1234567891, 4); + 1234567891.1234 + > SELECT _FUNC_(1234567891.1234567891, -4); + 1234560000 + > SELECT _FUNC_(1234567891.1234567891); + 1234567891 + """) +// scalastyle:on line.size.limit +case class Truncate(child: Expression, scale: Expression) + extends RoundBase(child, scale, BigDecimal.RoundingMode.DOWN, "ROUND_DOWN") + with Serializable with ImplicitCastInputTypes { + def this(child: Expression) = this(child, Literal(0)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 9eed2eb20204..eea471ca8cd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -287,6 +287,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) { longVal += (if (droppedDigits < 0) -1L else 1L) } + case ROUND_DOWN => case _ => sys.error(s"Not supported rounding mode: $roundMode") } @@ -413,6 +414,7 @@ object Decimal { val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR + val ROUND_DOWN = BigDecimal.RoundingMode.DOWN /** Maximum number of decimal digits an Int can represent */ val MAX_INT_DIGITS = 9 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 3a094079380f..b374424b711e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -644,4 +644,59 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BRound(-0.35, 1), -0.4) checkEvaluation(BRound(-35, -1), -40) } + + test("Truncate number") { + assert(Truncate(Literal.create(123.123, DoubleType), + NonFoldableLiteral.create(1, IntegerType)).checkInputDataTypes().isFailure) + assert(Truncate(Literal.create(123.123, DoubleType), + Literal.create(1, IntegerType)).checkInputDataTypes().isSuccess) + + def testDouble(input: Any, scale: Any, expected: Any): Unit = { + checkEvaluation(Truncate(Literal.create(input, DoubleType), + Literal.create(scale, IntegerType)), + expected) + } + + def testFloat(input: Any, scale: Any, expected: Any): Unit = { + checkEvaluation(Truncate(Literal.create(input, FloatType), + Literal.create(scale, IntegerType)), + expected) + } + + def testDecimal(input: Any, scale: Any, expected: Any): Unit = { + checkEvaluation(Truncate(Literal.create(input, DecimalType.DoubleDecimal), + Literal.create(scale, IntegerType)), + expected) + } + + testDouble(1234567891.1234567891D, 4, 1234567891.1234D) + testDouble(1234567891.1234567891D, -4, 1234560000D) + testDouble(1234567891.1234567891D, 0, 1234567891D) + testDouble(0.123D, -1, 0D) + testDouble(0.123D, 0, 0D) + testDouble(null, null, null) + testDouble(null, 0, null) + testDouble(1D, null, null) + testDouble(-1234567891.1234567891D, 4, -1234567891.1234D) + + testFloat(1234567891.1234567891F, 4, 1234567891.1234F) + testFloat(1234567891.1234567891F, -4, 1234560000F) + testFloat(1234567891.1234567891F, 0, 1234567891F) + testFloat(0.123F, -1, 0F) + testFloat(0.123F, 0, 0F) + testFloat(null, null, null) + testFloat(null, 0, null) + testFloat(1F, null, null) + testFloat(-1234567891.1234567891F, 4, -1234567891.1234F) + + testDecimal(Decimal(1234567891.1234567891), 4, Decimal(1234567891.1234)) + testDecimal(Decimal(-1234567891.1234567891), 4, Decimal(-1234567891.1234)) + testDecimal(Decimal(1234567891.1234567891), -4, Decimal(1234560000)) + testDecimal(Decimal(1234567891.1234567891), 0, Decimal(1234567891)) + testDecimal(Decimal(0.123), -1, Decimal(0)) + testDecimal(Decimal(0.123), 0, Decimal(0)) + testDecimal(null, null, null) + testDecimal(null, 0, null) + testDecimal(Decimal(1), null, null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 10de90c6a44c..962b7049fe62 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -204,7 +204,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { } test("changePrecision/toPrecision on compact decimal should respect rounding mode") { - Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => + Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_DOWN).foreach { mode => Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n => Seq("", "-").foreach { sign => val bd = BigDecimal(sign + n) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b67d7a1ca5..aed0faefd450 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2214,6 +2214,26 @@ object functions { */ def radians(columnName: String): Column = radians(Column(columnName)) + /** + * Returns the value of the column `e` truncated to 0 places. + * + * @group math_funcs + * @since 2.5.0 + */ + def truncate(e: Column): Column = truncate(e, 0) + + /** + * Returns the value of column `e` truncated to the unit specified by the scale. + * If scale is omitted, then the value of column `e` is truncated to 0 places. + * Scale can be negative to truncate (make zero) scale digits left of the decimal point. + * + * @group math_funcs + * @since 2.5.0 + */ + def truncate(e: Column, scale: Int): Column = withExpr { + Truncate(e.expr, Literal(scale)) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Misc functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 37f9cd44da7f..31d2a77240a2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -92,3 +92,12 @@ select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11); -- pmod select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null); select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint)); + +-- truncate +select truncate(cast(1234567891.1234567891 as double), -4), truncate(cast(1234567891.1234567891 as double), 0), truncate(cast(1234567891.1234567891 as double), 4); +select truncate(cast(1234567891.1234567891 as float), -4), truncate(cast(1234567891.1234567891 as float), 0), truncate(cast(1234567891.1234567891 as float), 4); +select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4); +select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4); +select truncate(cast(1234567891.1234567891 as long), 9.03); +select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal)); +select truncate(cast(-1234567891.1234567891 as double), -4), truncate(cast(-1234567891.1234567891 as double), 4); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index fd1d0db9e3f7..46ade81086e6 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 55 +-- Number of queries: 62 -- !query 0 @@ -452,3 +452,59 @@ select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint) struct -- !query 54 output NULL NULL + + +-- !query 55 +select truncate(cast(1234567891.1234567891 as double), -4), truncate(cast(1234567891.1234567891 as double), 0), truncate(cast(1234567891.1234567891 as double), 4) +-- !query 55 schema +struct +-- !query 55 output +1.23456E9 1.234567891E9 1.2345678911234E9 + + +-- !query 56 +select truncate(cast(1234567891.1234567891 as float), -4), truncate(cast(1234567891.1234567891 as float), 0), truncate(cast(1234567891.1234567891 as float), 4) +-- !query 56 schema +struct +-- !query 56 output +1.23456E9 1.23456794E9 1.23456794E9 + + +-- !query 57 +select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4) +-- !query 57 schema +struct +-- !query 57 output +1234560000 1234567891 1234567891 + + +-- !query 58 +select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4) +-- !query 58 schema +struct +-- !query 58 output +1234560000 1234567891 1234567891 + + +-- !query 59 +select truncate(cast(1234567891.1234567891 as long), 9.03) +-- !query 59 schema +struct +-- !query 59 output +1234567891 + + +-- !query 60 +select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal)) +-- !query 60 schema +struct +-- !query 60 output +1.234567891E9 1.23456794E9 1234567891 + + +-- !query 61 +select truncate(cast(-1234567891.1234567891 as double), -4), truncate(cast(-1234567891.1234567891 as double), 4) +-- !query 61 schema +struct +-- !query 61 output +-1.23456E9 -1.2345678911234E9