diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 0d40368c9cd6..4d4db3cf6121 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1083,15 +1083,24 @@ def to_timestamp(col, format=None):
 @since(1.5)
 def trunc(date, format):
     """
-    Returns date truncated to the unit specified by the format.
+    Returns a date truncated to the unit specified by the format, or
+    a number truncated to the specified number of decimal places.
 
-    :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
+    :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for dates;
+        any integer for numeric values.
 
     >>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
     >>> df.select(trunc(df.d, 'year').alias('year')).collect()
     [Row(year=datetime.date(1997, 1, 1))]
     >>> df.select(trunc(df.d, 'mon').alias('month')).collect()
     [Row(month=datetime.date(1997, 2, 1))]
+    >>> df = spark.createDataFrame([(1234567891.1234567891,)], ['d'])
+    >>> df.select(trunc(df.d, 4).alias('positive')).collect()
+    [Row(positive=1234567891.1234)]
+    >>> df.select(trunc(df.d, -4).alias('negative')).collect()
+    [Row(negative=1234560000.0)]
+    >>> df.select(trunc(df.d, 0).alias('zero')).collect()
+    [Row(zero=1234567891.0)]
     """
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 11538bd31b4f..fc602bbc6083 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -391,7 +391,6 @@ object FunctionRegistry {
     expression[ParseToDate]("to_date"),
     expression[ToUnixTimestamp]("to_unix_timestamp"),
     expression[ToUTCTimestamp]("to_utc_timestamp"),
-    expression[TruncDate]("trunc"),
     expression[UnixTimestamp]("unix_timestamp"),
     expression[DayOfWeek]("dayofweek"),
     expression[WeekOfYear]("weekofyear"),
@@ -426,6 +425,7 @@ object FunctionRegistry {
     expression[CurrentDatabase]("current_database"),
     expression[CallMethodViaReflection]("reflect"),
     expression[CallMethodViaReflection]("java_method"),
+    expression[Trunc]("trunc"),
 
     // grouping sets
     expression[Cube]("cube"),
- */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.", - examples = """ - Examples: - > SELECT _FUNC_('2009-02-12', 'MM'); - 2009-02-01 - > SELECT _FUNC_('2015-10-27', 'YEAR'); - 2015-01-01 - """, - since = "1.5.0") -// scalastyle:on line.size.limit -case class TruncDate(date: Expression, format: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - override def left: Expression = date - override def right: Expression = format - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) - override def dataType: DataType = DateType - override def nullable: Boolean = true - override def prettyName: String = "trunc" - - private lazy val truncLevel: Int = - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) - - override def eval(input: InternalRow): Any = { - val level = if (format.foldable) { - truncLevel - } else { - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) - } - if (level == -1) { - // unknown format - null - } else { - val d = date.eval(input) - if (d == null) { - null - } else { - DateTimeUtils.truncDate(d.asInstanceOf[Int], level) - } - } - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - - if (format.foldable) { - if (truncLevel == -1) { - ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") - } else { - val d = date.genCode(ctx) - ev.copy(code = s""" - ${d.code} - boolean ${ev.isNull} = ${d.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $dtu.truncDate(${d.value}, $truncLevel); - }""") - } - } else { - nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { - val form = ctx.freshName("form") - s""" - int $form = $dtu.parseTruncLevel($fmt); - if ($form == -1) { - ${ev.isNull} = true; - } else { - ${ev.value} = $dtu.truncDate($dateVal, $form); - } - """ - }) - } - } -} - /** * Returns the number of days from startDate to endDate. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index b86e271fe295..778079c7d65c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, MathUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -132,3 +133,214 @@ case class Uuid() extends LeafExpression { s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false") } } + +/** + * Returns date truncated to the unit specified by the format or + * numeric truncated to scale decimal places. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(data[, trunc_param]) - Returns `data` truncated by the format model `trunc_param`. + If `data` is date/timestamp/string type, returns `data` with the time portion of the day truncated to the unit specified by the format model `trunc_param`. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index b86e271fe295..778079c7d65c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -21,6 +21,7 @@ import java.util.UUID
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, MathUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -132,3 +133,214 @@ case class Uuid() extends LeafExpression {
     s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false")
   }
 }
+
+/**
+ * Returns date truncated to the unit specified by the format, or
+ * a number truncated to the specified number of decimal places.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(data[, trunc_param]) - Returns `data` truncated by the format model `trunc_param`.
+      If `data` is date/timestamp/string type, returns `data` with the time portion of the day truncated to the unit specified by the format model `trunc_param`.
+      If `trunc_param` is omitted, then the default `trunc_param` is 'MM'.
+      If `data` is decimal/double type, returns `data` truncated to `trunc_param` decimal places. If `trunc_param` is omitted, then the default `trunc_param` is 0.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('2009-02-12', 'MM');
+       2009-02-01
+      > SELECT _FUNC_('2015-10-27', 'YEAR');
+       2015-01-01
+      > SELECT _FUNC_('1989-03-13');
+       1989-03-01
+      > SELECT _FUNC_(1234567891.1234567891, 4);
+       1234567891.1234
+      > SELECT _FUNC_(1234567891.1234567891, -4);
+       1234560000
+      > SELECT _FUNC_(1234567891.1234567891);
+       1234567891
+  """,
+  since = "1.5.0")
+// scalastyle:on line.size.limit
+case class Trunc(data: Expression, truncExpr: Expression)
+  extends BinaryExpression with ExpectsInputTypes {
+
+  def this(data: Expression) = {
+    this(data, Literal(
+      if (data.dataType.isInstanceOf[DateType] ||
+          data.dataType.isInstanceOf[TimestampType] ||
+          data.dataType.isInstanceOf[StringType]) {
+        "MM"
+      } else {
+        0
+      })
+    )
+  }
+
+  override def left: Expression = data
+  override def right: Expression = truncExpr
+
+  private val isTruncNumber = truncExpr.dataType.isInstanceOf[IntegerType]
+  private val isTruncDate = truncExpr.dataType.isInstanceOf[StringType]
+
+  override def dataType: DataType = if (isTruncDate) DateType else data.dataType
+
+  override def inputTypes: Seq[AbstractDataType] = data.dataType match {
+    case NullType =>
+      Seq(dataType, TypeCollection(StringType, IntegerType))
+    case DateType | TimestampType | StringType =>
+      Seq(TypeCollection(DateType, TimestampType, StringType), StringType)
+    case DoubleType | DecimalType.Fixed(_, _) =>
+      Seq(TypeCollection(DoubleType, DecimalType), IntegerType)
+    case _ =>
+      Seq(TypeCollection(DateType, StringType, TimestampType, DoubleType, DecimalType),
+        TypeCollection(StringType, IntegerType))
+  }
+
+  override def nullable: Boolean = true
+
+  override def prettyName: String = "trunc"
+
+  private lazy val truncFormat: Int = if (isTruncNumber) {
+    truncExpr.eval().asInstanceOf[Int]
+  } else if (isTruncDate) {
+    DateTimeUtils.parseTruncLevel(truncExpr.eval().asInstanceOf[UTF8String])
+  } else {
+    0
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val d = data.eval(input)
+    val truncParam = truncExpr.eval()
+    if (null == d || null == truncParam) {
+      null
+    } else {
+      if (isTruncNumber) {
+        val scale = if (truncExpr.foldable) truncFormat else truncParam.asInstanceOf[Int]
+        data.dataType match {
+          case DoubleType => MathUtils.trunc(d.asInstanceOf[Double], scale)
+          case DecimalType.Fixed(_, _) =>
+            MathUtils.trunc(d.asInstanceOf[Decimal].toJavaBigDecimal, scale)
+        }
+      } else if (isTruncDate) {
+        val level = if (truncExpr.foldable) {
+          truncFormat
+        } else {
+          DateTimeUtils.parseTruncLevel(truncParam.asInstanceOf[UTF8String])
+        }
+        if (level == -1) {
+          // unknown format
+          null
+        } else {
+          data.dataType match {
+            case DateType => DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
+            case TimestampType =>
+              val ts = DateTimeUtils.timestampToString(d.asInstanceOf[Long])
+              val dt = DateTimeUtils.stringToDate(UTF8String.fromString(ts))
+              if (dt.isDefined) DateTimeUtils.truncDate(dt.get, level) else null
+            case StringType =>
+              val dt = DateTimeUtils.stringToDate(d.asInstanceOf[UTF8String])
+              if (dt.isDefined) DateTimeUtils.truncDate(dt.get, level) else null
+          }
+        }
+      } else {
+        null
+      }
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    if (isTruncNumber) {
+      val bdu = MathUtils.getClass.getName.stripSuffix("$")
+
+      if (truncExpr.foldable) {
+        val d = data.genCode(ctx)
+        ev.copy(code = s"""
+          ${d.code}
+          boolean ${ev.isNull} = ${d.isNull};
+          ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+          if (!${ev.isNull}) {
+            ${ev.value} = $bdu.trunc(${d.value}, $truncFormat);
+          }""")
+      } else {
+        nullSafeCodeGen(ctx, ev, (doubleVal, truncParam) =>
+          s"${ev.value} = $bdu.trunc($doubleVal, $truncParam);")
+      }
+    } else if (isTruncDate) {
+      val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+
+      if (truncExpr.foldable) {
+        if (truncFormat == -1) {
+          ev.copy(code = s"""
+            boolean ${ev.isNull} = true;
+            int ${ev.value} = ${ctx.defaultValue(DateType)};""")
+        } else {
+          val d = data.genCode(ctx)
+          val dt = ctx.freshName("dt")
+          val pre = s"""
+            ${d.code}
+            boolean ${ev.isNull} = ${d.isNull};
+            int ${ev.value} = ${ctx.defaultValue(DateType)};"""
+          data.dataType match {
+            case DateType =>
+              ev.copy(code = pre + s"""
+                if (!${ev.isNull}) {
+                  ${ev.value} = $dtu.truncDate(${d.value}, $truncFormat);
+                }""")
+            case TimestampType =>
+              val ts = ctx.freshName("ts")
+              ev.copy(code = pre + s"""
+                if (!${ev.isNull}) {
+                  String $ts = $dtu.timestampToString(${d.value});
+                  scala.Option $dt = $dtu.stringToDate(UTF8String.fromString($ts));
+                  if ($dt.isDefined()) {
+                    ${ev.value} = $dtu.truncDate((Integer) $dt.get(), $truncFormat);
+                  } else {
+                    ${ev.isNull} = true;
+                  }
+                }""")
+            case StringType =>
+              ev.copy(code = pre + s"""
+                if (!${ev.isNull}) {
+                  scala.Option $dt = $dtu.stringToDate(${d.value});
+                  if ($dt.isDefined()) {
+                    ${ev.value} = $dtu.truncDate((Integer) $dt.get(), $truncFormat);
+                  } else {
+                    ${ev.isNull} = true;
+                  }
+                }""")
+          }
+        }
+      } else {
+        nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
+          val truncParam = ctx.freshName("truncParam")
+          val dt = ctx.freshName("dt")
+          val pre = s"int $truncParam = $dtu.parseTruncLevel($fmt);"
+          data.dataType match {
+            case DateType =>
+              pre + s"""
+                if ($truncParam == -1) {
+                  ${ev.isNull} = true;
+                } else {
+                  ${ev.value} = $dtu.truncDate($dateVal, $truncParam);
+                }"""
+            case TimestampType =>
+              val ts = ctx.freshName("ts")
+              pre + s"""
+                String $ts = $dtu.timestampToString($dateVal);
+                scala.Option $dt = $dtu.stringToDate(UTF8String.fromString($ts));
+                if ($truncParam == -1 || $dt.isEmpty()) {
+                  ${ev.isNull} = true;
+                } else {
+                  ${ev.value} = $dtu.truncDate((Integer) $dt.get(), $truncParam);
+                }"""
+            case StringType =>
+              pre + s"""
+                scala.Option $dt = $dtu.stringToDate($dateVal);
+                if ($truncParam == -1 || $dt.isEmpty()) {
+                  ${ev.isNull} = true;
+                } else {
+                  ${ev.value} = $dtu.truncDate((Integer) $dt.get(), $truncParam);
+                }"""
+          }
+        })
+      }
+    } else {
+      nullSafeCodeGen(ctx, ev, (dataVal, fmt) => s"${ev.isNull} = true;")
+    }
+  }
+}
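The expression above dispatches on the type of its second argument: a string selects date truncation, an integer selects numeric truncation, and an unknown date format model yields NULL. A minimal spark-shell sketch of the resulting behaviour (assuming a running SparkSession `spark`; the expected values are taken from the SQL test results further down):

    spark.sql("SELECT trunc('2015-07-22', 'YEAR')").show()        // 2015-01-01
    spark.sql("SELECT trunc('2015-07-22', 'MM')").show()          // 2015-07-01
    spark.sql("SELECT trunc(1234567891.1234567891, -4)").show()   // 1234560000
    spark.sql("SELECT trunc('2015-07-22', 'DD')").show()          // NULL (unknown format model)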
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
new file mode 100644
index 000000000000..cc826545fbd4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import java.math.{BigDecimal => JBigDecimal}
+
+/**
+ * Helper functions for truncating numeric values.
+ */
+object MathUtils {
+
+  /**
+   * Returns the double input truncated to `scale` decimal places.
+   */
+  def trunc(input: Double, scale: Int): Double = {
+    trunc(JBigDecimal.valueOf(input), scale).doubleValue()
+  }
+
+  /**
+   * Returns the BigDecimal input truncated to `scale` decimal places.
+   */
+  def trunc(input: JBigDecimal, scale: Int): JBigDecimal = {
+    // Adapted from https://github.com/apache/hive/blob/release-2.3.0-rc0
+    // /ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java#L471-L487
+    // For both positive and negative scales the multiplier is 10^|scale|.
+    val pow = JBigDecimal.valueOf(Math.pow(10, Math.abs(scale)))
+
+    if (scale > 0) {
+      val longValue = input.multiply(pow).longValue()
+      JBigDecimal.valueOf(longValue).divide(pow)
+    } else if (scale == 0) {
+      JBigDecimal.valueOf(input.longValue())
+    } else {
+      val longValue = input.divide(pow).longValue()
+      JBigDecimal.valueOf(longValue).multiply(pow)
+    }
+  }
+}
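MathUtils.trunc scales the input by 10^|scale|, drops the remaining fraction with longValue(), and scales back, so negative scales zero out digits to the left of the decimal point. A standalone sketch of the same arithmetic (hypothetical object name, not part of this patch; expected values mirror MathUtilsSuite below):

    import java.math.{BigDecimal => JBigDecimal}

    object TruncSketch {
      // Scale by 10^|scale|, truncate toward zero via longValue(), scale back.
      def trunc(input: JBigDecimal, scale: Int): JBigDecimal = {
        val pow = JBigDecimal.valueOf(math.pow(10, math.abs(scale)))
        if (scale > 0) JBigDecimal.valueOf(input.multiply(pow).longValue()).divide(pow)
        else if (scale == 0) JBigDecimal.valueOf(input.longValue())
        else JBigDecimal.valueOf(input.divide(pow).longValue()).multiply(pow)
      }

      def main(args: Array[String]): Unit = {
        val bg = new JBigDecimal("1234567891.1234567891")
        println(trunc(bg, 4))   // 1234567891.1234
        println(trunc(bg, -4))  // 1234560000.0
        println(trunc(bg, 0))   // 1234567891
      }
    }

Note that the longValue() round trip limits intermediate results to the Long range, a property inherited from the Hive code the implementation is adapted from.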
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 89d99f9678cd..cdb239ead3af 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -527,27 +527,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
   }
 
-  test("function trunc") {
-    def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
-      checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)),
-        expected)
-      checkEvaluation(
-        TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
-        expected)
-    }
-    val date = Date.valueOf("2015-07-22")
-    Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
-      testTrunc(date, fmt, Date.valueOf("2015-01-01"))
-    }
-    Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
-      testTrunc(date, fmt, Date.valueOf("2015-07-01"))
-    }
-    testTrunc(date, "DD", null)
-    testTrunc(date, null, null)
-    testTrunc(null, "MON", null)
-    testTrunc(null, null, null)
-  }
-
   test("from_unixtime") {
     val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
     val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
index 4fe7b436982b..c65bc72f67fc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.{Date, Timestamp}
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.types._
 
@@ -44,4 +46,85 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(evaluate(Uuid()) !== evaluate(Uuid()))
   }
 
+  test("trunc numeric") {
+    def testTrunc(input: Double, fmt: Int, expected: Double): Unit = {
+      checkEvaluation(Trunc(Literal.create(input, DoubleType),
+        Literal.create(fmt, IntegerType)),
+        expected)
+      checkEvaluation(Trunc(Literal.create(input, DoubleType),
+        NonFoldableLiteral.create(fmt, IntegerType)),
+        expected)
+    }
+
+    testTrunc(1234567891.1234567891, 4, 1234567891.1234)
+    testTrunc(1234567891.1234567891, -4, 1234560000)
+    testTrunc(1234567891.1234567891, 0, 1234567891)
+    testTrunc(0.123, -1, 0)
+    testTrunc(0.123, 0, 0)
+
+    checkEvaluation(Trunc(Literal.create(1D, DoubleType),
+      NonFoldableLiteral.create(null, IntegerType)),
+      null)
+    checkEvaluation(Trunc(Literal.create(null, DoubleType),
+      NonFoldableLiteral.create(1, IntegerType)),
+      null)
+    checkEvaluation(Trunc(Literal.create(null, DoubleType),
+      NonFoldableLiteral.create(null, IntegerType)),
+      null)
+  }
+
+  test("trunc date") {
+    def testDate(input: Date, fmt: String, expected: Date): Unit = {
+      checkEvaluation(Trunc(Literal.create(input, DateType), Literal.create(fmt, StringType)),
+        expected)
+      checkEvaluation(
+        Trunc(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
+        expected)
+    }
+
+    def testString(input: String, fmt: String, expected: Date): Unit = {
+      checkEvaluation(Trunc(Literal.create(input, StringType), Literal.create(fmt, StringType)),
+        expected)
+      checkEvaluation(
+        Trunc(Literal.create(input, StringType), NonFoldableLiteral.create(fmt, StringType)),
+        expected)
+    }
+
+    def testTimestamp(input: Timestamp, fmt: String, expected: Date): Unit = {
+      checkEvaluation(Trunc(Literal.create(input, TimestampType), Literal.create(fmt, StringType)),
+        expected)
+      checkEvaluation(
+        Trunc(Literal.create(input, TimestampType), NonFoldableLiteral.create(fmt, StringType)),
+        expected)
+    }
+
+    val dateStr = "2015-07-22"
+    val date = Date.valueOf(dateStr)
+    val ts = new Timestamp(date.getTime)
+
+    Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
+      testDate(date, fmt, Date.valueOf("2015-01-01"))
+      testString(dateStr, fmt, Date.valueOf("2015-01-01"))
+      testTimestamp(ts, fmt, Date.valueOf("2015-01-01"))
+    }
+    Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
+      testDate(date, fmt, Date.valueOf("2015-07-01"))
+      testString(dateStr, fmt, Date.valueOf("2015-07-01"))
+      testTimestamp(ts, fmt, Date.valueOf("2015-07-01"))
+    }
+    testDate(date, "DD", null)
+    testDate(date, null, null)
+    testDate(null, "MON", null)
+    testDate(null, null, null)
+
+    testString(dateStr, "DD", null)
+    testString(dateStr, null, null)
+    testString(null, "MON", null)
+    testString(null, null, null)
+
+    testTimestamp(ts, "DD", null)
+    testTimestamp(ts, null, null)
+    testTimestamp(null, "MON", null)
+    testTimestamp(null, null, null)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MathUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MathUtilsSuite.scala
new file mode 100644
index 000000000000..a3afe26bb408
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MathUtilsSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.MathUtils._
+
+class MathUtilsSuite extends SparkFunSuite {
+
+  test("trunc number") {
+    val bg = 1234567891.1234567891D
+    assert(trunc(bg, 4) === 1234567891.1234)
+    assert(trunc(bg, -4) === 1234560000)
+    assert(trunc(bg, 0) === 1234567891)
+  }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6bbdfa3ad189..c3785d477676 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2090,6 +2090,20 @@ object functions {
    */
  def radians(columnName: String): Column = radians(Column(columnName))
 
+  /**
+   * Returns the numeric value truncated to `scale` decimal places.
+   * If `scale` is positive or 0, the value is truncated to `scale` places to the
+   * right of the decimal point.
+   * If `scale` is negative, the value is truncated to the absolute value of `scale`
+   * places to the left of the decimal point, i.e. that many digits immediately
+   * before the decimal point are zeroed out.
+   *
+   * @group math_funcs
+   * @since 2.3.0
+   */
+  def trunc(numeric: Column, scale: Int): Column = withExpr {
+    Trunc(numeric.expr, Literal(scale))
+  }
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Misc functions
   //////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -2787,7 +2801,7 @@ object functions {
    * @since 1.5.0
    */
   def trunc(date: Column, format: String): Column = withExpr {
-    TruncDate(date.expr, Literal(format))
+    Trunc(date.expr, Literal(format))
   }
 
   /**
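With this change the two DataFrame API overloads resolve to the same expression. A short usage sketch (assuming a SparkSession `spark` with its implicits in scope; the expected values match the doctests and suites above):

    import org.apache.spark.sql.functions._
    import spark.implicits._

    Seq(1234567891.1234567891).toDF("d")
      .select(trunc(col("d"), 4)).show()        // 1234567891.1234

    Seq("2015-07-22").toDF("d")
      .select(trunc(col("d"), "YEAR")).show()   // 2015-01-01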
diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
index adea2bfa82cd..3f8de29a7e9d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
@@ -14,14 +14,25 @@
 create temporary view ttf1 as select * from values
   (1, 2), (2, 3) as ttf1(current_date, current_timestamp);
- 
+
 select current_date, current_timestamp from ttf1;
 
 create temporary view ttf2 as select * from values
   (1, 2), (2, 3) as ttf2(a, b);
- 
+
 select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2;
 
 select a, b from ttf2 order by a, current_date;
+
+-- trunc date
+select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'),
+  trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'),
+  trunc(to_date('2015-07-22'), 'yy'), trunc(to_date('2015-07-22'), 'YY');
+select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'),
+  trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'),
+  trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM');
+select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null);
+select trunc('2015-07-2200', 'DD'), trunc('123', null);
+select trunc(null, 'MON'), trunc(null, null);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql
index 15d981985c55..dc3739c510b1 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql
@@ -96,3 +96,11 @@ select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11);
 -- pmod
 select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null);
 select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint));
+
+-- trunc
+select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891);
+select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null);
+select trunc(1234567891.1234567891, 'yyyy');
+select trunc(to_date('2015-07-22'), 4);
+select trunc('2015-07-22', 4);
+select trunc(false, 4);
diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
index 7b2f46f6c2a6..150dbc38e1f1 100644
--- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 9
+-- Number of queries: 14
 
 
 -- !query 0
@@ -79,3 +79,47 @@
 struct
 -- !query 8 output
 1	2
 2	3
+
+
+-- !query 9
+select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'),
+  trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'),
+  trunc(to_date('2015-07-22'), 'yy'), trunc(to_date('2015-07-22'), 'YY')
+-- !query 9 schema
+struct
+-- !query 9 output
+2015-01-01	2015-01-01	2015-01-01	2015-01-01	2015-01-01	2015-01-01
+
+
+-- !query 10
+select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'),
+  trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'),
+  trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM')
+-- !query 10 schema
+struct
+-- !query 10 output
+2015-07-01	2015-07-01	2015-07-01	2015-07-01	2015-07-01	2015-07-01
+
+
+-- !query 11
+select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null)
+-- !query 11 schema
+struct
+-- !query 11 output
+NULL	NULL
+
+
+-- !query 12
+select trunc('2015-07-2200', 'DD'), trunc('123', null)
+-- !query 12 schema
+struct
+-- !query 12 output
+NULL	NULL
+
+
+-- !query 13
+select trunc(null, 'MON'), trunc(null, null)
+-- !query 13 schema
+struct
+-- !query 13 output
+NULL	NULL
diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
index 237b618a8b90..b957c70c9024 100644
--- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 59
+-- Number of queries: 65
 
 
 -- !query 0
@@ -484,3 +484,55 @@
 select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint))
 struct
 -- !query 58 output
 NULL	NULL
+
+
+-- !query 59
+select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891)
+-- !query 59 schema
+struct
+-- !query 59 output
+1234567891.1234	1234560000	1234567891	1234567891
+
+
+-- !query 60
+select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null)
+-- !query 60 schema
+struct
+-- !query 60 output
+NULL	NULL	NULL
+
+
+-- !query 61
+select trunc(1234567891.1234567891, 'yyyy')
+-- !query 61 schema
+struct<>
+-- !query 61 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'trunc(1234567891.1234567891BD, 'yyyy')' due to data type mismatch: argument 2 requires int type, however, ''yyyy'' is of string type.; line 1 pos 7
+
+
+-- !query 62
+select trunc(to_date('2015-07-22'), 4)
+-- !query 62 schema
+struct<>
+-- !query 62 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'trunc(to_date('2015-07-22'), 4)' due to data type mismatch: argument 2 requires string type, however, '4' is of int type.; line 1 pos 7
+
+
+-- !query 63
+select trunc('2015-07-22', 4)
+-- !query 63 schema
+struct<>
+-- !query 63 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'trunc('2015-07-22', 4)' due to data type mismatch: argument 2 requires string type, however, '4' is of int type.; line 1 pos 7
+
+
+-- !query 64
+select trunc(false, 4)
+-- !query 64 schema
+struct<>
+-- !query 64 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'trunc(false, 4)' due to data type mismatch: argument 1 requires (date or string or timestamp or double or decimal) type, however, 'false' is of boolean type.; line 1 pos 7
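Because Trunc type-checks both argument shapes, mismatched argument pairs are rejected at analysis time; the operators.sql results above show the exact messages. A condensed sketch of what a caller sees (assuming a SparkSession `spark`; messages abridged from the test output above):

    spark.sql("SELECT trunc(1234567891.1234567891, 'yyyy')")
    // AnalysisException: ... argument 2 requires int type, however, ''yyyy'' is of string type.

    spark.sql("SELECT trunc(false, 4)")
    // AnalysisException: ... argument 1 requires (date or string or timestamp or double or decimal) type ...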