Skip to content

Commit a5ade70

Browse files
committed
Support TRUNC (number)
1 parent 6b68d61 commit a5ade70

File tree

12 files changed

+337
-106
lines changed

12 files changed

+337
-106
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,6 @@ object FunctionRegistry {
355355
expression[ParseToDate]("to_date"),
356356
expression[ToUnixTimestamp]("to_unix_timestamp"),
357357
expression[ToUTCTimestamp]("to_utc_timestamp"),
358-
expression[TruncDate]("trunc"),
359358
expression[UnixTimestamp]("unix_timestamp"),
360359
expression[WeekOfYear]("weekofyear"),
361360
expression[Year]("year"),
@@ -388,6 +387,7 @@ object FunctionRegistry {
388387
expression[CurrentDatabase]("current_database"),
389388
expression[CallMethodViaReflection]("reflect"),
390389
expression[CallMethodViaReflection]("java_method"),
390+
expression[Trunc]("trunc"),
391391

392392
// grouping sets
393393
expression[Cube]("cube"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,86 +1227,6 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child:
12271227
override def dataType: DataType = TimestampType
12281228
}
12291229

1230-
/**
1231-
* Returns date truncated to the unit specified by the format.
1232-
*/
1233-
// scalastyle:off line.size.limit
1234-
@ExpressionDescription(
1235-
usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.",
1236-
extended = """
1237-
Examples:
1238-
> SELECT _FUNC_('2009-02-12', 'MM');
1239-
2009-02-01
1240-
> SELECT _FUNC_('2015-10-27', 'YEAR');
1241-
2015-01-01
1242-
""")
1243-
// scalastyle:on line.size.limit
1244-
case class TruncDate(date: Expression, format: Expression)
1245-
extends BinaryExpression with ImplicitCastInputTypes {
1246-
override def left: Expression = date
1247-
override def right: Expression = format
1248-
1249-
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
1250-
override def dataType: DataType = DateType
1251-
override def nullable: Boolean = true
1252-
override def prettyName: String = "trunc"
1253-
1254-
private lazy val truncLevel: Int =
1255-
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
1256-
1257-
override def eval(input: InternalRow): Any = {
1258-
val level = if (format.foldable) {
1259-
truncLevel
1260-
} else {
1261-
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
1262-
}
1263-
if (level == -1) {
1264-
// unknown format
1265-
null
1266-
} else {
1267-
val d = date.eval(input)
1268-
if (d == null) {
1269-
null
1270-
} else {
1271-
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
1272-
}
1273-
}
1274-
}
1275-
1276-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1277-
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
1278-
1279-
if (format.foldable) {
1280-
if (truncLevel == -1) {
1281-
ev.copy(code = s"""
1282-
boolean ${ev.isNull} = true;
1283-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
1284-
} else {
1285-
val d = date.genCode(ctx)
1286-
ev.copy(code = s"""
1287-
${d.code}
1288-
boolean ${ev.isNull} = ${d.isNull};
1289-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
1290-
if (!${ev.isNull}) {
1291-
${ev.value} = $dtu.truncDate(${d.value}, $truncLevel);
1292-
}""")
1293-
}
1294-
} else {
1295-
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
1296-
val form = ctx.freshName("form")
1297-
s"""
1298-
int $form = $dtu.parseTruncLevel($fmt);
1299-
if ($form == -1) {
1300-
${ev.isNull} = true;
1301-
} else {
1302-
${ev.value} = $dtu.truncDate($dateVal, $form);
1303-
}
1304-
"""
1305-
})
1306-
}
1307-
}
1308-
}
1309-
13101230
/**
13111231
* Returns the number of days from startDate to endDate.
13121232
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen._
22+
import org.apache.spark.sql.catalyst.util.{BigDecimalUtils, DateTimeUtils}
2223
import org.apache.spark.sql.types._
24+
import org.apache.spark.unsafe.types.UTF8String
2325

2426
/**
2527
* Print the result of an expression to stderr (used for debugging codegen).
@@ -104,3 +106,138 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable {
104106
override def nullable: Boolean = false
105107
override def prettyName: String = "current_database"
106108
}
109+
110+
/**
111+
* Returns date truncated to the unit specified by the format or
112+
* numeric truncated to scale decimal places.
113+
*/
114+
// scalastyle:off line.size.limit
115+
@ExpressionDescription(
116+
usage = """
117+
_FUNC_(data, fmt) - Returns `data` truncated by the format model `fmt`.
118+
If `data` is DateType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`.
119+
If `data` is DoubleType, returns `data` truncated to `fmt` decimal places.
120+
""",
121+
extended = """
122+
Examples:
123+
> SELECT _FUNC_('2009-02-12', 'MM');
124+
2009-02-01
125+
> SELECT _FUNC_('2015-10-27', 'YEAR');
126+
2015-01-01
127+
> SELECT _FUNC_(1234567891.1234567891, 4);
128+
1234567891.1234
129+
> SELECT _FUNC_(1234567891.1234567891, -4);
130+
1234560000
131+
""")
132+
// scalastyle:on line.size.limit
133+
case class Trunc(data: Expression, format: Expression = Literal(0))
  extends BinaryExpression with ImplicitCastInputTypes {

  /** Single-argument form: truncates `numeric` to zero decimal places. */
  def this(numeric: Expression) = {
    this(numeric, Literal(0))
  }

  override def left: Expression = data
  override def right: Expression = format

  /** The result has the same type as the truncated input (double or date). */
  override def dataType: DataType = data.dataType

  // NOTE(review): each argument is validated independently, so nothing here enforces the
  // (double, int) / (date, string) pairing that eval and doGenCode rely on - confirm the
  // analyzer never produces a mismatched pair.
  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(DoubleType, DateType), TypeCollection(StringType, IntegerType))

  // NULL is returned for NULL inputs and for date formats that cannot be parsed.
  override def nullable: Boolean = true
  override def prettyName: String = "trunc"

  /** Value of a foldable `format`, computed once. Only read when `format.foldable`. */
  private lazy val foldedFormat: Any = format.eval()

  /**
   * Truncation level for a foldable, non-null date format; -1 marks an unknown format.
   * Only read for DateType input after `foldedFormat` was checked for null.
   */
  private lazy val foldedTruncLevel: Int =
    DateTimeUtils.parseTruncLevel(foldedFormat.asInstanceOf[UTF8String])

  /**
   * Interpreted evaluation.
   *
   * @param input the row both children are evaluated against
   * @return the truncated double or date, or null when either argument is null or the
   *         date format is unknown
   */
  override def eval(input: InternalRow): Any = {
    val value = data.eval(input)
    // The format must be evaluated against the current row: a non-foldable format may be
    // derived from a column, and evaluating it without the row would not see that column.
    val fmt = format.eval(input)
    if (value == null || fmt == null) {
      null
    } else {
      dataType match {
        case DoubleType =>
          val number = value.asInstanceOf[Double]
          if (number.isNaN || number.isInfinite) {
            // NaN and +/-Infinity have no BigDecimal representation (the conversion inside
            // BigDecimalUtils.trunc would throw), so they are returned unchanged.
            number
          } else {
            BigDecimalUtils.trunc(number, fmt.asInstanceOf[Int]).doubleValue()
          }
        case DateType =>
          val level = if (format.foldable) {
            foldedTruncLevel
          } else {
            DateTimeUtils.parseTruncLevel(fmt.asInstanceOf[UTF8String])
          }
          if (level == -1) {
            // unknown format
            null
          } else {
            DateTimeUtils.truncDate(value.asInstanceOf[Int], level)
          }
      }
    }
  }

  /** Generated code for a result that is NULL regardless of the input row. */
  private def alwaysNullCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code = s"""
      boolean ${ev.isNull} = true;
      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
  }

  /**
   * Generated code for a foldable, non-null format: only `data` is evaluated per row and
   * `assign` (given the Java term holding the data value) emits the statement(s) that set
   * `ev.value`.
   */
  private def foldedFormatCode(
      ctx: CodegenContext,
      ev: ExprCode,
      assign: String => String): ExprCode = {
    val d = data.genCode(ctx)
    ev.copy(code = s"""
      ${d.code}
      boolean ${ev.isNull} = ${d.isNull};
      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
      if (!${ev.isNull}) {
        ${assign(d.value)}
      }""")
  }

  /**
   * Code generation. Mirrors `eval`: a foldable format is resolved once at generation time
   * (a NULL or unknown one short-circuits to a constant NULL), a non-foldable format is
   * evaluated per row through `nullSafeCodeGen`.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    dataType match {
      case DoubleType =>
        val bdu = BigDecimalUtils.getClass.getName.stripSuffix("$")

        // NaN / Infinity are passed through, exactly as in eval.
        def truncDouble(value: String, scale: String): String =
          s"""
            if (Double.isNaN($value) || Double.isInfinite($value)) {
              ${ev.value} = $value;
            } else {
              ${ev.value} = $bdu.trunc($value, $scale).doubleValue();
            }"""

        if (format.foldable) {
          if (foldedFormat == null) {
            // Casting a null scale to Int would silently yield 0; the result must be NULL.
            alwaysNullCode(ctx, ev)
          } else {
            val scale = foldedFormat.asInstanceOf[Int].toString
            foldedFormatCode(ctx, ev, value => truncDouble(value, scale))
          }
        } else {
          nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => truncDouble(doubleVal, fmt))
        }

      case DateType =>
        val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")

        if (format.foldable) {
          if (foldedFormat == null || foldedTruncLevel == -1) {
            // NULL or unknown format
            alwaysNullCode(ctx, ev)
          } else {
            foldedFormatCode(ctx, ev, value =>
              s"${ev.value} = $dtu.truncDate($value, $foldedTruncLevel);")
          }
        } else {
          nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
            val form = ctx.freshName("form")
            s"""
              int $form = $dtu.parseTruncLevel($fmt);
              if ($form == -1) {
                ${ev.isNull} = true;
              } else {
                ${ev.value} = $dtu.truncDate($dateVal, $form);
              }
            """
          })
        }
    }
  }
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.util
19+
20+
import java.math.{BigDecimal => JBigDecimal}
21+
22+
/**
 * Helper functions for BigDecimal.
 */
object BigDecimalUtils {
  import java.math.RoundingMode

  /**
   * Returns double type input truncated to scale decimal places.
   *
   * @param input the value to truncate; NaN and +/-Infinity cannot be converted to a
   *              BigDecimal and make `BigDecimal.valueOf` throw NumberFormatException
   * @param scale number of decimal places to keep; a negative value zeroes out that many
   *              digits to the left of the decimal point
   * @return the truncated value (digits are dropped, never rounded)
   */
  def trunc(input: Double, scale: Int): JBigDecimal = {
    trunc(JBigDecimal.valueOf(input), scale)
  }

  /**
   * Returns BigDecimal type input truncated to scale decimal places.
   *
   * Truncation always moves toward zero. The computation stays in BigDecimal arithmetic,
   * so values outside the Long range are handled correctly. The result is numerically
   * exact; callers should compare it with `compareTo`, since its BigDecimal scale is not
   * normalized.
   *
   * @param input the value to truncate
   * @param scale number of decimal places to keep; a negative value zeroes out that many
   *              digits to the left of the decimal point
   * @return the truncated value
   */
  def trunc(input: JBigDecimal, scale: Int): JBigDecimal = {
    // |input| < 10^integerDigits, because the unscaled value has `precision` digits.
    val integerDigits = input.precision().toLong - input.scale()
    if (scale >= input.scale()) {
      // Nothing to drop: the input has no digits beyond the requested position. Returning
      // early also avoids padding the value out to an arbitrarily large scale.
      input
    } else if (scale < 0 && -(scale.toLong) >= integerDigits) {
      // Every digit lies to the right of the cut-off position, so the result is zero.
      JBigDecimal.ZERO
    } else {
      input.setScale(scale, RoundingMode.DOWN)
    }
  }
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -513,27 +513,6 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
513513
NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
514514
}
515515

516-
test("function trunc") {
517-
def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
518-
checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)),
519-
expected)
520-
checkEvaluation(
521-
TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
522-
expected)
523-
}
524-
val date = Date.valueOf("2015-07-22")
525-
Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
526-
testTrunc(date, fmt, Date.valueOf("2015-01-01"))
527-
}
528-
Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
529-
testTrunc(date, fmt, Date.valueOf("2015-07-01"))
530-
}
531-
testTrunc(date, "DD", null)
532-
testTrunc(date, null, null)
533-
testTrunc(null, "MON", null)
534-
testTrunc(null, null, null)
535-
}
536-
537516
test("from_unixtime") {
538517
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
539518
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import java.sql.Date
21+
2022
import org.apache.spark.SparkFunSuite
2123
import org.apache.spark.sql.types._
2224

@@ -39,4 +41,50 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
3941
checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null)
4042
}
4143

44+
test("trunc") {

  // Numeric input: each case runs once with a foldable scale and once with a
  // non-foldable one.
  def checkNumber(value: Double, scale: Int, truncated: Double): Unit = {
    val number = Literal.create(value, DoubleType)
    checkEvaluation(Trunc(number, Literal.create(scale, IntegerType)), truncated)
    checkEvaluation(Trunc(number, NonFoldableLiteral.create(scale, IntegerType)), truncated)
  }

  val sample = 1234567891.1234567891
  checkNumber(sample, 4, 1234567891.1234)
  checkNumber(sample, -4, 1234560000D)
  checkNumber(sample, 0, 1234567891D)

  // A null number and/or a null scale always evaluates to null.
  Seq[(Any, Any)]((1D, null), (null, 1), (null, null)).foreach { case (value, scale) =>
    checkEvaluation(
      Trunc(Literal.create(value, DoubleType), NonFoldableLiteral.create(scale, IntegerType)),
      null)
  }

  // Date input: same foldable / non-foldable pairing for the format string.
  def checkDate(day: Date, fmt: String, truncated: Date): Unit = {
    val dateLiteral = Literal.create(day, DateType)
    checkEvaluation(Trunc(dateLiteral, Literal.create(fmt, StringType)), truncated)
    checkEvaluation(Trunc(dateLiteral, NonFoldableLiteral.create(fmt, StringType)), truncated)
  }

  val day = Date.valueOf("2015-07-22")

  val yearFormats = Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY")
  yearFormats.foreach(fmt => checkDate(day, fmt, Date.valueOf("2015-01-01")))

  val monthFormats = Seq("month", "MONTH", "mon", "MON", "mm", "MM")
  monthFormats.foreach(fmt => checkDate(day, fmt, Date.valueOf("2015-07-01")))

  // Unsupported format, then null format and/or null date.
  checkDate(day, "DD", null)
  checkDate(day, null, null)
  checkDate(null, "MON", null)
  checkDate(null, null, null)
}
4290
}

0 commit comments

Comments
 (0)