-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-29605][SQL] Optimize string to interval casting #26256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b22f243
c33710f
ffa1bee
a4c16a1
f249f49
d077c87
80f5a06
479d5bd
d68f41e
8e8e539
9dfb45d
8c3fb28
bc006a2
7515981
78a2e8e
f61e6f8
dd8f2d1
1204656
a2d91c3
94bd39b
98dd44f
0cd7e88
107d16c
8dd9518
464eacc
dbad971
527b00e
5e96a0e
2222f13
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal | |
|
|
||
| import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} | ||
| import org.apache.spark.sql.types.Decimal | ||
| import org.apache.spark.unsafe.types.CalendarInterval | ||
| import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} | ||
|
|
||
| object IntervalUtils { | ||
| final val MONTHS_PER_YEAR: Int = 12 | ||
|
|
@@ -39,6 +39,7 @@ object IntervalUtils { | |
| final val MICROS_PER_MONTH: Long = DAYS_PER_MONTH * DateTimeUtils.MICROS_PER_DAY | ||
| /* 365.25 days per year assumes leap year every four years */ | ||
| final val MICROS_PER_YEAR: Long = (36525L * DateTimeUtils.MICROS_PER_DAY) / 100 | ||
| final val DAYS_PER_WEEK: Byte = 7 | ||
|
|
||
| def getYears(interval: CalendarInterval): Int = { | ||
| interval.months / MONTHS_PER_YEAR | ||
|
|
@@ -388,4 +389,194 @@ object IntervalUtils { | |
| def divide(interval: CalendarInterval, num: Double): CalendarInterval = { | ||
| fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num) | ||
| } | ||
|
|
||
| private object ParseState extends Enumeration { | ||
| val PREFIX, | ||
| BEGIN_VALUE, | ||
| PARSE_SIGN, | ||
| PARSE_UNIT_VALUE, | ||
| FRACTIONAL_PART, | ||
| BEGIN_UNIT_NAME, | ||
| UNIT_NAME_SUFFIX, | ||
| END_UNIT_NAME = Value | ||
| } | ||
| private final val intervalStr = UTF8String.fromString("interval ") | ||
| private final val yearStr = UTF8String.fromString("year") | ||
| private final val monthStr = UTF8String.fromString("month") | ||
| private final val weekStr = UTF8String.fromString("week") | ||
| private final val dayStr = UTF8String.fromString("day") | ||
| private final val hourStr = UTF8String.fromString("hour") | ||
| private final val minuteStr = UTF8String.fromString("minute") | ||
| private final val secondStr = UTF8String.fromString("second") | ||
| private final val millisStr = UTF8String.fromString("millisecond") | ||
| private final val microsStr = UTF8String.fromString("microsecond") | ||
|
|
||
| def stringToInterval(input: UTF8String): CalendarInterval = { | ||
| import ParseState._ | ||
|
|
||
| if (input == null) { | ||
| return null | ||
| } | ||
| // scalastyle:off caselocale .toLowerCase | ||
| val s = input.trim.toLowerCase | ||
| // scalastyle:on | ||
| val bytes = s.getBytes | ||
| if (bytes.length == 0) { | ||
| return null | ||
| } | ||
| var state = PREFIX | ||
| var i = 0 | ||
| var currentValue: Long = 0 | ||
| var isNegative: Boolean = false | ||
| var months: Int = 0 | ||
| var days: Int = 0 | ||
| var microseconds: Long = 0 | ||
| var fractionScale: Int = 0 | ||
| var fraction: Int = 0 | ||
|
|
||
| while (i < bytes.length) { | ||
| val b = bytes(i) | ||
| state match { | ||
| case PREFIX => | ||
| if (s.startsWith(intervalStr)) { | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if (s.numBytes() == intervalStr.numBytes()) { | ||
| return null | ||
| } else { | ||
| i += intervalStr.numBytes() | ||
| } | ||
| } | ||
| state = BEGIN_VALUE | ||
| case BEGIN_VALUE => | ||
| b match { | ||
| case ' ' => i += 1 | ||
| case _ => state = PARSE_SIGN | ||
| } | ||
| case PARSE_SIGN => | ||
| b match { | ||
| case '-' => | ||
| isNegative = true | ||
| i += 1 | ||
| case '+' => | ||
| isNegative = false | ||
| i += 1 | ||
| case _ if '0' <= b && b <= '9' => | ||
| isNegative = false | ||
| case _ => return null | ||
| } | ||
| currentValue = 0 | ||
| fraction = 0 | ||
| // Sets the scale to an invalid value to track fraction presence | ||
| // in the BEGIN_UNIT_NAME state | ||
| fractionScale = -1 | ||
| state = PARSE_UNIT_VALUE | ||
| case PARSE_UNIT_VALUE => | ||
| b match { | ||
| case _ if '0' <= b && b <= '9' => | ||
| try { | ||
| currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0')) | ||
| } catch { | ||
| case _: ArithmeticException => return null | ||
| } | ||
| case ' ' => | ||
| state = BEGIN_UNIT_NAME | ||
| case '.' => | ||
| fractionScale = (DateTimeUtils.NANOS_PER_SECOND / 10).toInt | ||
| state = FRACTIONAL_PART | ||
| case _ => return null | ||
| } | ||
| i += 1 | ||
| case FRACTIONAL_PART => | ||
| b match { | ||
| case _ if '0' <= b && b <= '9' && fractionScale > 0 => | ||
| fraction += (b - '0') * fractionScale | ||
| fractionScale /= 10 | ||
| case ' ' => | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. spark-sql> select cast('1. seconds' as interval);
interval 1 seconds
spark-sql> select cast('1. days' as interval);
interval 1 days |
||
| fraction /= DateTimeUtils.NANOS_PER_MICROS.toInt | ||
| state = BEGIN_UNIT_NAME | ||
| case _ => return null | ||
| } | ||
| i += 1 | ||
| case BEGIN_UNIT_NAME => | ||
| if (b == ' ') { | ||
| i += 1 | ||
| } else { | ||
| // Checks that only seconds can have the fractional part | ||
| if (b != 's' && fractionScale >= 0) { | ||
| return null | ||
| } | ||
| if (isNegative) { | ||
| currentValue = -currentValue | ||
| fraction = -fraction | ||
| } | ||
| try { | ||
| b match { | ||
| case 'y' if s.matchAt(yearStr, i) => | ||
| val monthsInYears = Math.multiplyExact(MONTHS_PER_YEAR, currentValue) | ||
| months = Math.toIntExact(Math.addExact(months, monthsInYears)) | ||
| i += yearStr.numBytes() | ||
| case 'w' if s.matchAt(weekStr, i) => | ||
| val daysInWeeks = Math.multiplyExact(DAYS_PER_WEEK, currentValue) | ||
| days = Math.toIntExact(Math.addExact(days, daysInWeeks)) | ||
| i += weekStr.numBytes() | ||
| case 'd' if s.matchAt(dayStr, i) => | ||
| days = Math.addExact(days, Math.toIntExact(currentValue)) | ||
| i += dayStr.numBytes() | ||
| case 'h' if s.matchAt(hourStr, i) => | ||
| val hoursUs = Math.multiplyExact(currentValue, MICROS_PER_HOUR) | ||
| microseconds = Math.addExact(microseconds, hoursUs) | ||
| i += hourStr.numBytes() | ||
| case 's' if s.matchAt(secondStr, i) => | ||
| val secondsUs = Math.multiplyExact(currentValue, DateTimeUtils.MICROS_PER_SECOND) | ||
| microseconds = Math.addExact(Math.addExact(microseconds, secondsUs), fraction) | ||
| i += secondStr.numBytes() | ||
| case 'm' => | ||
| if (s.matchAt(monthStr, i)) { | ||
| months = Math.addExact(months, Math.toIntExact(currentValue)) | ||
| i += monthStr.numBytes() | ||
| } else if (s.matchAt(minuteStr, i)) { | ||
| val minutesUs = Math.multiplyExact(currentValue, MICROS_PER_MINUTE) | ||
| microseconds = Math.addExact(microseconds, minutesUs) | ||
| i += minuteStr.numBytes() | ||
| } else if (s.matchAt(millisStr, i)) { | ||
| val millisUs = Math.multiplyExact( | ||
| currentValue, | ||
| DateTimeUtils.MICROS_PER_MILLIS) | ||
| microseconds = Math.addExact(microseconds, millisUs) | ||
| i += millisStr.numBytes() | ||
| } else if (s.matchAt(microsStr, i)) { | ||
| microseconds = Math.addExact(microseconds, currentValue) | ||
| i += microsStr.numBytes() | ||
| } else return null | ||
| case _ => return null | ||
| } | ||
| } catch { | ||
| case _: ArithmeticException => return null | ||
| } | ||
| state = UNIT_NAME_SUFFIX | ||
| } | ||
| case UNIT_NAME_SUFFIX => | ||
| b match { | ||
| case 's' => state = END_UNIT_NAME | ||
| case ' ' => state = BEGIN_VALUE | ||
| case _ => return null | ||
| } | ||
| i += 1 | ||
| case END_UNIT_NAME => | ||
| b match { | ||
| case ' ' => | ||
| i += 1 | ||
| state = BEGIN_VALUE | ||
| case _ => return null | ||
| } | ||
| } | ||
| } | ||
|
|
||
| val result = state match { | ||
| case UNIT_NAME_SUFFIX | END_UNIT_NAME | BEGIN_VALUE => | ||
| new CalendarInterval(months, days, microseconds) | ||
| case _ => null | ||
| } | ||
|
|
||
| result | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,11 +22,39 @@ import java.util.concurrent.TimeUnit | |
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.sql.catalyst.util.DateTimeUtils.{MICROS_PER_MILLIS, MICROS_PER_SECOND} | ||
| import org.apache.spark.sql.catalyst.util.IntervalUtils._ | ||
| import org.apache.spark.unsafe.types.CalendarInterval | ||
| import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} | ||
|
|
||
| class IntervalUtilsSuite extends SparkFunSuite { | ||
|
|
||
| test("fromString: basic") { | ||
| private def checkFromString(input: String, expected: CalendarInterval): Unit = { | ||
| assert(fromString(input) === expected) | ||
| assert(stringToInterval(UTF8String.fromString(input)) === expected) | ||
| } | ||
|
|
||
| private def checkFromInvalidString(input: String, errorMsg: String): Unit = { | ||
| try { | ||
| fromString(input) | ||
| fail("Expected to throw an exception for the invalid input") | ||
| } catch { | ||
| case e: IllegalArgumentException => | ||
| val msg = e.getMessage | ||
| assert(msg.contains(errorMsg)) | ||
| } | ||
| assert(stringToInterval(UTF8String.fromString(input)) === null) | ||
| } | ||
|
|
||
| private def testSingleUnit( | ||
| unit: String, number: Int, months: Int, days: Int, microseconds: Long): Unit = { | ||
| for (prefix <- Seq("interval ", "")) { | ||
| val input1 = prefix + number + " " + unit | ||
| val input2 = prefix + number + " " + unit + "s" | ||
| val result = new CalendarInterval(months, days, microseconds) | ||
| checkFromString(input1, result) | ||
| checkFromString(input2, result) | ||
| } | ||
| } | ||
|
|
||
| test("string to interval: basic") { | ||
| testSingleUnit("YEAR", 3, 36, 0, 0) | ||
| testSingleUnit("Month", 3, 3, 0, 0) | ||
| testSingleUnit("Week", 3, 0, 21, 0) | ||
|
|
@@ -37,58 +65,46 @@ class IntervalUtilsSuite extends SparkFunSuite { | |
| testSingleUnit("MilliSecond", 3, 0, 0, 3 * MICROS_PER_MILLIS) | ||
| testSingleUnit("MicroSecond", 3, 0, 0, 3) | ||
|
|
||
| for (input <- Seq(null, "", " ")) { | ||
| try { | ||
| fromString(input) | ||
| fail("Expected to throw an exception for the invalid input") | ||
| } catch { | ||
| case e: IllegalArgumentException => | ||
| val msg = e.getMessage | ||
| if (input == null) { | ||
| assert(msg.contains("cannot be null")) | ||
| } | ||
| } | ||
| } | ||
| checkFromInvalidString(null, "cannot be null") | ||
|
|
||
| for (input <- Seq("interval", "interval1 day", "foo", "foo 1 day")) { | ||
| try { | ||
| fromString(input) | ||
| fail("Expected to throw an exception for the invalid input") | ||
| } catch { | ||
| case e: IllegalArgumentException => | ||
| val msg = e.getMessage | ||
| assert(msg.contains("Invalid interval string")) | ||
| } | ||
| for (input <- Seq("", " ", "interval", "interval1 day", "foo", "foo 1 day")) { | ||
| checkFromInvalidString(input, "Invalid interval string") | ||
| } | ||
| } | ||
|
|
||
| test("fromString: random order field") { | ||
| val input = "1 day 1 year" | ||
| val result = new CalendarInterval(12, 1, 0) | ||
| assert(fromString(input) == result) | ||
| } | ||
|
|
||
| test("fromString: duplicated fields") { | ||
| val input = "1 day 1 day" | ||
| val result = new CalendarInterval(0, 2, 0) | ||
| assert(fromString(input) == result) | ||
| test("string to interval: multiple units") { | ||
| Seq( | ||
| "-1 MONTH 1 day -1 microseconds" -> new CalendarInterval(-1, 1, -1), | ||
| " 123 MONTHS 123 DAYS 123 Microsecond " -> new CalendarInterval(123, 123, 123), | ||
| "interval -1 day +3 Microseconds" -> new CalendarInterval(0, -1, 3), | ||
| " interval 8 years -11 months 123 weeks -1 day " + | ||
| "23 hours -22 minutes 1 second -123 millisecond 567 microseconds " -> | ||
| new CalendarInterval(85, 860, 81480877567L)).foreach { case (input, expected) => | ||
| checkFromString(input, expected) | ||
| } | ||
| } | ||
|
|
||
| test("fromString: value with +/-") { | ||
| val input = "+1 year -1 day" | ||
| val result = new CalendarInterval(12, -1, 0) | ||
| assert(fromString(input) == result) | ||
| test("string to interval: special cases") { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we also test |
||
| // Support any order of interval units | ||
| checkFromString("1 day 1 year", new CalendarInterval(12, 1, 0)) | ||
| // Allow duplicated units and summarize their values | ||
| checkFromString("1 day 10 day", new CalendarInterval(0, 11, 0)) | ||
| // Only the seconds units can have the fractional part | ||
| checkFromInvalidString("1.5 days", "Error parsing interval string") | ||
| checkFromInvalidString("1. hour", "Error parsing interval string") | ||
| } | ||
|
|
||
| private def testSingleUnit( | ||
| unit: String, number: Int, months: Int, days: Int, microseconds: Long): Unit = { | ||
| for (prefix <- Seq("interval ", "")) { | ||
| val input1 = prefix + number + " " + unit | ||
| val input2 = prefix + number + " " + unit + "s" | ||
| val result = new CalendarInterval(months, days, microseconds) | ||
| assert(fromString(input1) == result) | ||
| assert(fromString(input2) == result) | ||
| } | ||
| test("string to interval: seconds with fractional part") { | ||
| checkFromString("0.1 seconds", new CalendarInterval(0, 0, 100000)) | ||
| checkFromString("1. seconds", new CalendarInterval(0, 0, 1000000)) | ||
| checkFromString("123.001 seconds", new CalendarInterval(0, 0, 123001000)) | ||
| checkFromString("1.001001 seconds", new CalendarInterval(0, 0, 1001001)) | ||
| checkFromString("1 minute 1.001001 seconds", new CalendarInterval(0, 0, 61001001)) | ||
| checkFromString("-1.5 seconds", new CalendarInterval(0, 0, -1500000)) | ||
| // truncate nanoseconds to microseconds | ||
| checkFromString("0.999999999 seconds", new CalendarInterval(0, 0, 999999)) | ||
| checkFromInvalidString("0.123456789123 seconds", "Error parsing interval string") | ||
| } | ||
|
|
||
| test("from year-month string") { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know you added
finalfor performance reasons, but do we have actual performance diffs of the benchmarks below with/without thisfinal? (Just out of curiosity...There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I re-ran
IntervalBenchmarkwithoutfinals, there is no difference, actually.