1717
1818package org .apache .spark .sql .catalyst .util
1919
20+ import java .util .Locale
2021import java .util .concurrent .TimeUnit
2122
2223import scala .util .control .NonFatal
@@ -26,6 +27,8 @@ import org.apache.spark.sql.types.Decimal
2627import org .apache .spark .unsafe .types .CalendarInterval
2728
2829object IntervalUtils {
30+ import IntervalUnit ._
31+
2932 final val MONTHS_PER_YEAR : Int = 12
3033 final val MONTHS_PER_QUARTER : Byte = 3
3134 final val YEARS_PER_MILLENNIUM : Int = 1000
@@ -126,13 +129,13 @@ object IntervalUtils {
126129 }
127130
/**
 * Converts a textual interval field into a `Long` and checks it against an inclusive range.
 *
 * @param fieldName unit the value belongs to; only used to label the error message
 * @param s         text to convert with `toLong`; a `null` reference is read as zero
 * @param minValue  smallest accepted value (inclusive)
 * @param maxValue  largest accepted value (inclusive)
 * @return the converted value
 * @throws IllegalArgumentException if the value is outside `[minValue, maxValue]`
 *                                  (a non-numeric `s` surfaces as the `NumberFormatException`
 *                                  raised by `toLong`, which is a subclass)
 */
private def toLongWithRange(
    fieldName: IntervalUnit,
    s: String,
    minValue: Long,
    maxValue: Long): Long = {
  // A missing field contributes nothing, so it is treated as zero rather than as an error.
  val parsed: Long = s match {
    case null => 0L
    case text => text.toLong
  }
  val withinBounds = parsed >= minValue && parsed <= maxValue
  require(
    withinBounds,
    s"${fieldName.toString} $parsed outside range [$minValue, $maxValue]")
  parsed
}
@@ -148,8 +151,8 @@ object IntervalUtils {
148151 require(input != null , " Interval year-month string must be not null" )
149152 def toInterval (yearStr : String , monthStr : String ): CalendarInterval = {
150153 try {
151- val years = toLongWithRange(" year " , yearStr, 0 , Integer .MAX_VALUE ).toInt
152- val months = toLongWithRange(" month " , monthStr, 0 , 11 ).toInt
154+ val years = toLongWithRange(YEAR , yearStr, 0 , Integer .MAX_VALUE ).toInt
155+ val months = toLongWithRange(MONTH , monthStr, 0 , 11 ).toInt
153156 val totalMonths = Math .addExact(Math .multiplyExact(years, 12 ), months)
154157 new CalendarInterval (totalMonths, 0 , 0 )
155158 } catch {
@@ -176,7 +179,7 @@ object IntervalUtils {
176179 * adapted from HiveIntervalDayTime.valueOf
177180 */
def fromDayTimeString(s: String): CalendarInterval =
  // Single-argument form: delegate to the three-argument overload with the bounds DAY and SECOND.
  fromDayTimeString(s, DAY, SECOND)
181184
182185 private val dayTimePattern =
@@ -191,7 +194,7 @@ object IntervalUtils {
191194 * - HOUR TO (MINUTE|SECOND)
192195 * - MINUTE TO SECOND
193196 */
194- def fromDayTimeString (input : String , from : String , to : String ): CalendarInterval = {
197+ def fromDayTimeString (input : String , from : IntervalUnit , to : IntervalUnit ): CalendarInterval = {
195198 require(input != null , " Interval day-time string must be not null" )
196199 assert(input.length == input.trim.length)
197200 val m = dayTimePattern.pattern.matcher(input)
@@ -202,34 +205,34 @@ object IntervalUtils {
202205 val days = if (m.group(2 ) == null ) {
203206 0
204207 } else {
205- toLongWithRange(" day " , m.group(3 ), 0 , Integer .MAX_VALUE ).toInt
208+ toLongWithRange(DAY , m.group(3 ), 0 , Integer .MAX_VALUE ).toInt
206209 }
207210 var hours : Long = 0L
208211 var minutes : Long = 0L
209212 var seconds : Long = 0L
210- if (m.group(5 ) != null || from == " minute " ) { // 'HH:mm:ss' or 'mm:ss minute'
211- hours = toLongWithRange(" hour " , m.group(5 ), 0 , 23 )
212- minutes = toLongWithRange(" minute " , m.group(6 ), 0 , 59 )
213- seconds = toLongWithRange(" second " , m.group(7 ), 0 , 59 )
213+ if (m.group(5 ) != null || from == MINUTE ) { // 'HH:mm:ss' or 'mm:ss minute'
214+ hours = toLongWithRange(HOUR , m.group(5 ), 0 , 23 )
215+ minutes = toLongWithRange(MINUTE , m.group(6 ), 0 , 59 )
216+ seconds = toLongWithRange(SECOND , m.group(7 ), 0 , 59 )
214217 } else if (m.group(8 ) != null ) { // 'mm:ss.nn'
215- minutes = toLongWithRange(" minute " , m.group(6 ), 0 , 59 )
216- seconds = toLongWithRange(" second " , m.group(7 ), 0 , 59 )
218+ minutes = toLongWithRange(MINUTE , m.group(6 ), 0 , 59 )
219+ seconds = toLongWithRange(SECOND , m.group(7 ), 0 , 59 )
217220 } else { // 'HH:mm'
218- hours = toLongWithRange(" hour " , m.group(6 ), 0 , 23 )
219- minutes = toLongWithRange(" second " , m.group(7 ), 0 , 59 )
221+ hours = toLongWithRange(HOUR , m.group(6 ), 0 , 23 )
222+ minutes = toLongWithRange(SECOND , m.group(7 ), 0 , 59 )
220223 }
221224 // Hive allow nanosecond precision interval
222225 var secondsFraction = parseNanos(m.group(9 ), seconds < 0 )
223226 to match {
224- case " hour " =>
227+ case HOUR =>
225228 minutes = 0
226229 seconds = 0
227230 secondsFraction = 0
228- case " minute " =>
231+ case MINUTE =>
229232 seconds = 0
230233 secondsFraction = 0
231- case " second " =>
232- // No-op
234+ case SECOND =>
235+ // No-op
233236 case _ =>
234237 throw new IllegalArgumentException (
235238 s " Cannot support (interval ' $input' $from to $to) expression " )
@@ -246,36 +249,36 @@ object IntervalUtils {
246249 }
247250 }
248251
249- def fromUnitStrings (units : Array [String ], values : Array [String ]): CalendarInterval = {
250- assert(units.length == values .length)
252+ def fromUnitStrings (units : Array [IntervalUnit ], fields : Array [String ]): CalendarInterval = {
253+ assert(units.length == fields .length)
251254 var months : Int = 0
252255 var days : Int = 0
253256 var microseconds : Long = 0
254257 var i = 0
255258 while (i < units.length) {
256259 try {
257260 units(i) match {
258- case " year " =>
259- months = Math .addExact(months, Math .multiplyExact(values (i).toInt, 12 ))
260- case " month " =>
261- months = Math .addExact(months, values (i).toInt)
262- case " week " =>
263- days = Math .addExact(days, Math .multiplyExact(values (i).toInt, 7 ))
264- case " day " =>
265- days = Math .addExact(days, values (i).toInt)
266- case " hour " =>
267- val hoursUs = Math .multiplyExact(values (i).toLong, MICROS_PER_HOUR )
261+ case YEAR =>
262+ months = Math .addExact(months, Math .multiplyExact(fields (i).toInt, 12 ))
263+ case MONTH =>
264+ months = Math .addExact(months, fields (i).toInt)
265+ case WEEK =>
266+ days = Math .addExact(days, Math .multiplyExact(fields (i).toInt, 7 ))
267+ case DAY =>
268+ days = Math .addExact(days, fields (i).toInt)
269+ case HOUR =>
270+ val hoursUs = Math .multiplyExact(fields (i).toLong, MICROS_PER_HOUR )
268271 microseconds = Math .addExact(microseconds, hoursUs)
269- case " minute " =>
270- val minutesUs = Math .multiplyExact(values (i).toLong, MICROS_PER_MINUTE )
272+ case MINUTE =>
273+ val minutesUs = Math .multiplyExact(fields (i).toLong, MICROS_PER_MINUTE )
271274 microseconds = Math .addExact(microseconds, minutesUs)
272- case " second " =>
273- microseconds = Math .addExact(microseconds, parseSecondNano(values (i)))
274- case " millisecond " =>
275- val millisUs = Math .multiplyExact(values (i).toLong, DateTimeUtils .MICROS_PER_MILLIS )
275+ case SECOND =>
276+ microseconds = Math .addExact(microseconds, parseSecondNano(fields (i)))
277+ case MILLISECOND =>
278+ val millisUs = Math .multiplyExact(fields (i).toLong, DateTimeUtils .MICROS_PER_MILLIS )
276279 microseconds = Math .addExact(microseconds, millisUs)
277- case " microsecond " =>
278- microseconds = Math .addExact(microseconds, values (i).toLong)
280+ case MICROSECOND =>
281+ microseconds = Math .addExact(microseconds, fields (i).toLong)
279282 }
280283 } catch {
281284 case e : Exception =>
@@ -293,7 +296,7 @@ object IntervalUtils {
293296 val alignedStr = if (nanosStr.length < maxNanosLen) {
294297 (nanosStr + " 000000000" ).substring(0 , maxNanosLen)
295298 } else nanosStr
296- val nanos = toLongWithRange(" nanosecond " , alignedStr, 0L , 999999999L )
299+ val nanos = toLongWithRange(NANOSECOND , alignedStr, 0L , 999999999L )
297300 val micros = nanos / DateTimeUtils .NANOS_PER_MICROS
298301 if (isNegative) - micros else micros
299302 } else {
@@ -307,7 +310,7 @@ object IntervalUtils {
307310 private def parseSecondNano (secondNano : String ): Long = {
308311 def parseSeconds (secondsStr : String ): Long = {
309312 toLongWithRange(
310- " second " ,
313+ SECOND ,
311314 secondsStr,
312315 Long .MinValue / DateTimeUtils .MICROS_PER_SECOND ,
313316 Long .MaxValue / DateTimeUtils .MICROS_PER_SECOND ) * DateTimeUtils .MICROS_PER_SECOND
@@ -366,3 +369,35 @@ object IntervalUtils {
366369 getDuration(interval, TimeUnit .MICROSECONDS , daysPerMonth) < 0
367370 }
368371}
372+
/**
 * Units in which an interval field can be expressed.
 *
 * `NANOSECOND` is declared as a value, but neither parser in this object maps any
 * string to it.
 */
object IntervalUnit extends Enumeration {
  type IntervalUnit = Value

  val YEAR, MONTH, WEEK, DAY, HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND, NANOSECOND = Value

  /**
   * Resolves a unit name case-insensitively, accepting the full singular/plural names as
   * well as the short aliases listed in the match below (e.g. "yr", "mon", "h", "ms").
   *
   * @param unit the unit name to resolve
   * @return the matching unit
   * @throws IllegalArgumentException if `unit` is `null` or not a recognised name or alias
   */
  def fromString(unit: String): IntervalUnit = normalize(unit) match {
    case "year" | "years" | "y" | "yr" | "yrs" => YEAR
    case "month" | "months" | "mon" | "mons" => MONTH
    case "week" | "weeks" | "w" => WEEK
    case "day" | "days" | "d" => DAY
    case "hour" | "hours" | "h" | "hr" | "hrs" => HOUR
    case "minute" | "minutes" | "m" | "min" | "mins" => MINUTE
    case "second" | "seconds" | "s" | "sec" | "secs" => SECOND
    case "millisecond" | "milliseconds" | "ms" | "msec" | "msecs" | "mseconds" => MILLISECOND
    case "microsecond" | "microseconds" | "us" | "usec" | "usecs" | "useconds" => MICROSECOND
    case u => throw new IllegalArgumentException(s"Invalid interval unit: $u")
  }

  /**
   * Resolves a unit name case-insensitively, accepting only the full singular or plural
   * name of each unit; the short aliases understood by [[fromString]] are rejected.
   *
   * @param unit the unit name to resolve
   * @return the matching unit
   * @throws IllegalArgumentException if `unit` is `null` or not a full unit name
   */
  def fromStringStrict(unit: String): IntervalUnit = normalize(unit) match {
    case "year" | "years" => YEAR
    case "month" | "months" => MONTH
    case "week" | "weeks" => WEEK
    case "day" | "days" => DAY
    case "hour" | "hours" => HOUR
    case "minute" | "minutes" => MINUTE
    case "second" | "seconds" => SECOND
    case "millisecond" | "milliseconds" => MILLISECOND
    case "microsecond" | "microseconds" => MICROSECOND
    case u => throw new IllegalArgumentException(s"Invalid interval unit: $u")
  }

  /**
   * Lower-cases `unit` with the root locale so matching does not depend on the JVM's
   * default locale. A `null` input is rejected here with the same exception type as any
   * other invalid unit, instead of failing with a `NullPointerException` in `toLowerCase`.
   */
  private def normalize(unit: String): String = {
    if (unit == null) {
      throw new IllegalArgumentException("Invalid interval unit: null")
    }
    unit.toLowerCase(Locale.ROOT)
  }
}
0 commit comments