Skip to content

Commit 91c2ee2

Browse files
committed
[SPARK-29713][SQL] Support Interval Unit Abbreviations in Interval Literals
1 parent 14337f6 commit 91c2ee2

File tree

7 files changed

+184
-119
lines changed

7 files changed

+184
-119
lines changed

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,7 +1446,7 @@ CURRENT_USER: 'CURRENT_USER';
14461446
DATA: 'DATA';
14471447
DATABASE: 'DATABASE';
14481448
DATABASES: 'DATABASES' | 'SCHEMAS';
1449-
DAY: 'DAY';
1449+
DAY: 'DAY' | 'D';
14501450
DAYS: 'DAYS';
14511451
DBPROPERTIES: 'DBPROPERTIES';
14521452
DEFINED: 'DEFINED';
@@ -1491,7 +1491,7 @@ GRANT: 'GRANT';
14911491
GROUP: 'GROUP';
14921492
GROUPING: 'GROUPING';
14931493
HAVING: 'HAVING';
1494-
HOUR: 'HOUR';
1494+
HOUR: 'HOUR' | 'H' | 'HR' | 'HRS';
14951495
HOURS: 'HOURS';
14961496
IF: 'IF';
14971497
IGNORE: 'IGNORE';
@@ -1528,13 +1528,13 @@ LOCKS: 'LOCKS';
15281528
LOGICAL: 'LOGICAL';
15291529
MACRO: 'MACRO';
15301530
MAP: 'MAP';
1531-
MICROSECOND: 'MICROSECOND';
1531+
MICROSECOND: 'MICROSECOND' | 'US' | 'USEC' | 'USECS' | 'USECONDS';
15321532
MICROSECONDS: 'MICROSECONDS';
1533-
MILLISECOND: 'MILLISECOND';
1533+
MILLISECOND: 'MILLISECOND' | 'MS' | 'MSEC' | 'MSECS' | 'MSECONDS';
15341534
MILLISECONDS: 'MILLISECONDS';
1535-
MINUTE: 'MINUTE';
1535+
MINUTE: 'MINUTE' | 'M' | 'MIN' | 'MINS';
15361536
MINUTES: 'MINUTES';
1537-
MONTH: 'MONTH';
1537+
MONTH: 'MONTH' | 'MON' | 'MONS';
15381538
MONTHS: 'MONTHS';
15391539
MSCK: 'MSCK';
15401540
NAMESPACE: 'NAMESPACE';
@@ -1594,7 +1594,7 @@ ROLLUP: 'ROLLUP';
15941594
ROW: 'ROW';
15951595
ROWS: 'ROWS';
15961596
SCHEMA: 'SCHEMA';
1597-
SECOND: 'SECOND';
1597+
SECOND: 'SECOND' | 'S' | 'SEC' | 'SECS';
15981598
SECONDS: 'SECONDS';
15991599
SELECT: 'SELECT';
16001600
SEMI: 'SEMI';
@@ -1648,13 +1648,13 @@ USER: 'USER';
16481648
USING: 'USING';
16491649
VALUES: 'VALUES';
16501650
VIEW: 'VIEW';
1651-
WEEK: 'WEEK';
1651+
WEEK: 'WEEK' | 'W';
16521652
WEEKS: 'WEEKS';
16531653
WHEN: 'WHEN';
16541654
WHERE: 'WHERE';
16551655
WINDOW: 'WINDOW';
16561656
WITH: 'WITH';
1657-
YEAR: 'YEAR';
1657+
YEAR: 'YEAR' | 'Y' | 'YR' | 'YRS';
16581658
YEARS: 'YEARS';
16591659
//============================
16601660
// End of the keywords list

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
3636
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
3737
import org.apache.spark.sql.catalyst.plans._
3838
import org.apache.spark.sql.catalyst.plans.logical._
39+
import org.apache.spark.sql.catalyst.util.{IntervalUnit, IntervalUtils}
3940
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
40-
import org.apache.spark.sql.catalyst.util.IntervalUtils
4141
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
4242
import org.apache.spark.sql.internal.SQLConf
4343
import org.apache.spark.sql.types._
@@ -50,6 +50,7 @@ import org.apache.spark.util.random.RandomSampler
5050
*/
5151
class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging {
5252
import ParserUtils._
53+
import IntervalUnit._
5354

5455
def this() = this(new SQLConf())
5556

@@ -103,11 +104,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
103104

104105
override def visitSingleInterval(ctx: SingleIntervalContext): CalendarInterval = {
105106
withOrigin(ctx) {
106-
val units = ctx.intervalUnit().asScala.map {
107-
u => normalizeInternalUnit(u.getText.toLowerCase(Locale.ROOT))
108-
}.toArray
109-
val values = ctx.intervalValue().asScala.map(getIntervalValue).toArray
110107
try {
108+
val units = ctx.intervalUnit().asScala.map(u => IntervalUnit.fromString(u.getText)).toArray
109+
val values = ctx.intervalValue().asScala.map(getIntervalValue).toArray
111110
IntervalUtils.fromUnitStrings(units, values)
112111
} catch {
113112
case i: IllegalArgumentException =>
@@ -1960,24 +1959,23 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
19601959
import ctx._
19611960
val s = getIntervalValue(value)
19621961
try {
1963-
val unitText = unit.getText.toLowerCase(Locale.ROOT)
1964-
val interval = (unitText, Option(to).map(_.getText.toLowerCase(Locale.ROOT))) match {
1962+
val ut = IntervalUnit.fromStringStrict(unit.getText)
1963+
val interval = (ut, Option(to).map(u => IntervalUnit.fromStringStrict(u.getText))) match {
19651964
case (u, None) =>
1966-
IntervalUtils.fromUnitStrings(Array(normalizeInternalUnit(u)), Array(s))
1967-
case ("year", Some("month")) =>
1968-
IntervalUtils.fromYearMonthString(s)
1969-
case ("day", Some("hour")) =>
1970-
IntervalUtils.fromDayTimeString(s, "day", "hour")
1971-
case ("day", Some("minute")) =>
1972-
IntervalUtils.fromDayTimeString(s, "day", "minute")
1973-
case ("day", Some("second")) =>
1974-
IntervalUtils.fromDayTimeString(s, "day", "second")
1975-
case ("hour", Some("minute")) =>
1976-
IntervalUtils.fromDayTimeString(s, "hour", "minute")
1977-
case ("hour", Some("second")) =>
1978-
IntervalUtils.fromDayTimeString(s, "hour", "second")
1979-
case ("minute", Some("second")) =>
1980-
IntervalUtils.fromDayTimeString(s, "minute", "second")
1965+
IntervalUtils.fromUnitStrings(Array(u), Array(s))
1966+
case (IntervalUnit.YEAR, Some(IntervalUnit.MONTH)) => IntervalUtils.fromYearMonthString(s)
1967+
case (IntervalUnit.DAY, Some(IntervalUnit.HOUR)) =>
1968+
IntervalUtils.fromDayTimeString(s, IntervalUnit.DAY, IntervalUnit.HOUR)
1969+
case (IntervalUnit.DAY, Some(IntervalUnit.MINUTE)) =>
1970+
IntervalUtils.fromDayTimeString(s, IntervalUnit.DAY, IntervalUnit.MINUTE)
1971+
case (IntervalUnit.DAY, Some(IntervalUnit.SECOND)) =>
1972+
IntervalUtils.fromDayTimeString(s, IntervalUnit.DAY, IntervalUnit.SECOND)
1973+
case (IntervalUnit.HOUR, Some(IntervalUnit.MINUTE)) =>
1974+
IntervalUtils.fromDayTimeString(s, IntervalUnit.HOUR, IntervalUnit.MINUTE)
1975+
case (IntervalUnit.HOUR, Some(IntervalUnit.SECOND)) =>
1976+
IntervalUtils.fromDayTimeString(s, IntervalUnit.HOUR, IntervalUnit.SECOND)
1977+
case (IntervalUnit.MINUTE, Some(IntervalUnit.SECOND)) =>
1978+
IntervalUtils.fromDayTimeString(s, IntervalUnit.MINUTE, IntervalUnit.SECOND)
19811979
case (from, Some(t)) =>
19821980
throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx)
19831981
}
@@ -2000,11 +1998,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
20001998
}
20011999
}
20022000

2003-
// Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
2004-
private def normalizeInternalUnit(s: String): String = {
2005-
if (s.endsWith("s")) s.substring(0, s.length - 1) else s
2006-
}
2007-
20082001
/* ********************************************************************************************
20092002
* DataType parsing
20102003
* ******************************************************************************************** */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala

Lines changed: 76 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.util
1919

20+
import java.util.Locale
2021
import java.util.concurrent.TimeUnit
2122

2223
import scala.util.control.NonFatal
@@ -26,6 +27,8 @@ import org.apache.spark.sql.types.Decimal
2627
import org.apache.spark.unsafe.types.CalendarInterval
2728

2829
object IntervalUtils {
30+
import IntervalUnit._
31+
2932
final val MONTHS_PER_YEAR: Int = 12
3033
final val MONTHS_PER_QUARTER: Byte = 3
3134
final val YEARS_PER_MILLENNIUM: Int = 1000
@@ -126,13 +129,13 @@ object IntervalUtils {
126129
}
127130

128131
private def toLongWithRange(
129-
fieldName: String,
132+
fieldName: IntervalUnit,
130133
s: String,
131134
minValue: Long,
132135
maxValue: Long): Long = {
133136
val result = if (s == null) 0L else s.toLong
134137
require(minValue <= result && result <= maxValue,
135-
s"$fieldName $result outside range [$minValue, $maxValue]")
138+
s"${fieldName.toString} $result outside range [$minValue, $maxValue]")
136139

137140
result
138141
}
@@ -148,8 +151,8 @@ object IntervalUtils {
148151
require(input != null, "Interval year-month string must be not null")
149152
def toInterval(yearStr: String, monthStr: String): CalendarInterval = {
150153
try {
151-
val years = toLongWithRange("year", yearStr, 0, Integer.MAX_VALUE).toInt
152-
val months = toLongWithRange("month", monthStr, 0, 11).toInt
154+
val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE).toInt
155+
val months = toLongWithRange(MONTH, monthStr, 0, 11).toInt
153156
val totalMonths = Math.addExact(Math.multiplyExact(years, 12), months)
154157
new CalendarInterval(totalMonths, 0, 0)
155158
} catch {
@@ -176,7 +179,7 @@ object IntervalUtils {
176179
* adapted from HiveIntervalDayTime.valueOf
177180
*/
178181
def fromDayTimeString(s: String): CalendarInterval = {
179-
fromDayTimeString(s, "day", "second")
182+
fromDayTimeString(s, DAY, SECOND)
180183
}
181184

182185
private val dayTimePattern =
@@ -191,7 +194,7 @@ object IntervalUtils {
191194
* - HOUR TO (MINUTE|SECOND)
192195
* - MINUTE TO SECOND
193196
*/
194-
def fromDayTimeString(input: String, from: String, to: String): CalendarInterval = {
197+
def fromDayTimeString(input: String, from: IntervalUnit, to: IntervalUnit): CalendarInterval = {
195198
require(input != null, "Interval day-time string must be not null")
196199
assert(input.length == input.trim.length)
197200
val m = dayTimePattern.pattern.matcher(input)
@@ -202,34 +205,34 @@ object IntervalUtils {
202205
val days = if (m.group(2) == null) {
203206
0
204207
} else {
205-
toLongWithRange("day", m.group(3), 0, Integer.MAX_VALUE).toInt
208+
toLongWithRange(DAY, m.group(3), 0, Integer.MAX_VALUE).toInt
206209
}
207210
var hours: Long = 0L
208211
var minutes: Long = 0L
209212
var seconds: Long = 0L
210-
if (m.group(5) != null || from == "minute") { // 'HH:mm:ss' or 'mm:ss minute'
211-
hours = toLongWithRange("hour", m.group(5), 0, 23)
212-
minutes = toLongWithRange("minute", m.group(6), 0, 59)
213-
seconds = toLongWithRange("second", m.group(7), 0, 59)
213+
if (m.group(5) != null || from == MINUTE) { // 'HH:mm:ss' or 'mm:ss minute'
214+
hours = toLongWithRange(HOUR, m.group(5), 0, 23)
215+
minutes = toLongWithRange(MINUTE, m.group(6), 0, 59)
216+
seconds = toLongWithRange(SECOND, m.group(7), 0, 59)
214217
} else if (m.group(8) != null) { // 'mm:ss.nn'
215-
minutes = toLongWithRange("minute", m.group(6), 0, 59)
216-
seconds = toLongWithRange("second", m.group(7), 0, 59)
218+
minutes = toLongWithRange(MINUTE, m.group(6), 0, 59)
219+
seconds = toLongWithRange(SECOND, m.group(7), 0, 59)
217220
} else { // 'HH:mm'
218-
hours = toLongWithRange("hour", m.group(6), 0, 23)
219-
minutes = toLongWithRange("second", m.group(7), 0, 59)
221+
hours = toLongWithRange(HOUR, m.group(6), 0, 23)
222+
minutes = toLongWithRange(SECOND, m.group(7), 0, 59)
220223
}
221224
// Hive allow nanosecond precision interval
222225
var secondsFraction = parseNanos(m.group(9), seconds < 0)
223226
to match {
224-
case "hour" =>
227+
case HOUR =>
225228
minutes = 0
226229
seconds = 0
227230
secondsFraction = 0
228-
case "minute" =>
231+
case MINUTE =>
229232
seconds = 0
230233
secondsFraction = 0
231-
case "second" =>
232-
// No-op
234+
case SECOND =>
235+
// No-op
233236
case _ =>
234237
throw new IllegalArgumentException(
235238
s"Cannot support (interval '$input' $from to $to) expression")
@@ -246,36 +249,36 @@ object IntervalUtils {
246249
}
247250
}
248251

249-
def fromUnitStrings(units: Array[String], values: Array[String]): CalendarInterval = {
250-
assert(units.length == values.length)
252+
def fromUnitStrings(units: Array[IntervalUnit], fields: Array[String]): CalendarInterval = {
253+
assert(units.length == fields.length)
251254
var months: Int = 0
252255
var days: Int = 0
253256
var microseconds: Long = 0
254257
var i = 0
255258
while (i < units.length) {
256259
try {
257260
units(i) match {
258-
case "year" =>
259-
months = Math.addExact(months, Math.multiplyExact(values(i).toInt, 12))
260-
case "month" =>
261-
months = Math.addExact(months, values(i).toInt)
262-
case "week" =>
263-
days = Math.addExact(days, Math.multiplyExact(values(i).toInt, 7))
264-
case "day" =>
265-
days = Math.addExact(days, values(i).toInt)
266-
case "hour" =>
267-
val hoursUs = Math.multiplyExact(values(i).toLong, MICROS_PER_HOUR)
261+
case YEAR =>
262+
months = Math.addExact(months, Math.multiplyExact(fields(i).toInt, 12))
263+
case MONTH =>
264+
months = Math.addExact(months, fields(i).toInt)
265+
case WEEK =>
266+
days = Math.addExact(days, Math.multiplyExact(fields(i).toInt, 7))
267+
case DAY =>
268+
days = Math.addExact(days, fields(i).toInt)
269+
case HOUR =>
270+
val hoursUs = Math.multiplyExact(fields(i).toLong, MICROS_PER_HOUR)
268271
microseconds = Math.addExact(microseconds, hoursUs)
269-
case "minute" =>
270-
val minutesUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MINUTE)
272+
case MINUTE =>
273+
val minutesUs = Math.multiplyExact(fields(i).toLong, MICROS_PER_MINUTE)
271274
microseconds = Math.addExact(microseconds, minutesUs)
272-
case "second" =>
273-
microseconds = Math.addExact(microseconds, parseSecondNano(values(i)))
274-
case "millisecond" =>
275-
val millisUs = Math.multiplyExact(values(i).toLong, DateTimeUtils.MICROS_PER_MILLIS)
275+
case SECOND =>
276+
microseconds = Math.addExact(microseconds, parseSecondNano(fields(i)))
277+
case MILLISECOND =>
278+
val millisUs = Math.multiplyExact(fields(i).toLong, DateTimeUtils.MICROS_PER_MILLIS)
276279
microseconds = Math.addExact(microseconds, millisUs)
277-
case "microsecond" =>
278-
microseconds = Math.addExact(microseconds, values(i).toLong)
280+
case MICROSECOND =>
281+
microseconds = Math.addExact(microseconds, fields(i).toLong)
279282
}
280283
} catch {
281284
case e: Exception =>
@@ -293,7 +296,7 @@ object IntervalUtils {
293296
val alignedStr = if (nanosStr.length < maxNanosLen) {
294297
(nanosStr + "000000000").substring(0, maxNanosLen)
295298
} else nanosStr
296-
val nanos = toLongWithRange("nanosecond", alignedStr, 0L, 999999999L)
299+
val nanos = toLongWithRange(NANOSECOND, alignedStr, 0L, 999999999L)
297300
val micros = nanos / DateTimeUtils.NANOS_PER_MICROS
298301
if (isNegative) -micros else micros
299302
} else {
@@ -307,7 +310,7 @@ object IntervalUtils {
307310
private def parseSecondNano(secondNano: String): Long = {
308311
def parseSeconds(secondsStr: String): Long = {
309312
toLongWithRange(
310-
"second",
313+
SECOND,
311314
secondsStr,
312315
Long.MinValue / DateTimeUtils.MICROS_PER_SECOND,
313316
Long.MaxValue / DateTimeUtils.MICROS_PER_SECOND) * DateTimeUtils.MICROS_PER_SECOND
@@ -366,3 +369,35 @@ object IntervalUtils {
366369
getDuration(interval, TimeUnit.MICROSECONDS, daysPerMonth) < 0
367370
}
368371
}
372+
373+
object IntervalUnit extends Enumeration {
374+
type IntervalUnit = Value
375+
376+
val YEAR, MONTH, WEEK, DAY, HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND, NANOSECOND = Value
377+
378+
def fromString(unit: String): IntervalUnit = unit.toLowerCase(Locale.ROOT) match {
379+
case "year" | "years" | "y" | "yr" | "yrs" => YEAR
380+
case "month" | "months" | "mon" | "mons" => MONTH
381+
case "week" | "weeks" | "w" => WEEK
382+
case "day" | "days" | "d" => DAY
383+
case "hour" | "hours" | "h" | "hr" | "hrs" => HOUR
384+
case "minute" | "minutes" | "m" | "min" | "mins" => MINUTE
385+
case "second" | "seconds" | "s" | "sec" | "secs" => SECOND
386+
case "millisecond" | "milliseconds" | "ms" | "msec" | "msecs" | "mseconds" => MILLISECOND
387+
case "microsecond" | "microseconds" | "us" | "usec" | "usecs" | "useconds" => MICROSECOND
388+
case u => throw new IllegalArgumentException(s"Invalid interval unit: $u")
389+
}
390+
391+
def fromStringStrict(unit: String): IntervalUnit = unit.toLowerCase(Locale.ROOT) match {
392+
case "year" | "years" => YEAR
393+
case "month" | "months" => MONTH
394+
case "week" | "weeks" => WEEK
395+
case "day" | "days" => DAY
396+
case "hour" | "hours" => HOUR
397+
case "minute" | "minutes" => MINUTE
398+
case "second" | "seconds" => SECOND
399+
case "millisecond" | "milliseconds" => MILLISECOND
400+
case "microsecond" | "microseconds" => MICROSECOND
401+
case u => throw new IllegalArgumentException(s"Invalid interval unit: $u")
402+
}
403+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
2424
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
2525
import org.apache.spark.sql.catalyst.expressions._
2626
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
27-
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, IntervalUtils}
27+
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, IntervalUnit, IntervalUtils}
2828
import org.apache.spark.sql.internal.SQLConf
2929
import org.apache.spark.sql.types._
3030
import org.apache.spark.unsafe.types.CalendarInterval
@@ -597,7 +597,7 @@ class ExpressionParserSuite extends AnalysisTest {
597597
"microsecond")
598598

599599
def intervalLiteral(u: String, s: String): Literal = {
600-
Literal(IntervalUtils.fromUnitStrings(Array(u), Array(s)))
600+
Literal(IntervalUtils.fromUnitStrings(Array(IntervalUnit.fromString(u)), Array(s)))
601601
}
602602

603603
test("intervals") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class IntervalUtilsSuite extends SparkFunSuite {
145145
}
146146

147147
try {
148-
fromDayTimeString("5 1:12:20", "hour", "microsecond")
148+
fromDayTimeString("5 1:12:20", IntervalUnit.DAY, IntervalUnit.MICROSECOND)
149149
fail("Expected to throw an exception for the invalid convention type")
150150
} catch {
151151
case e: IllegalArgumentException =>

0 commit comments

Comments
 (0)