Skip to content

Commit 829cfe7

Browse files
committed
[SPARK-30341][SQL] Overflow check for interval arithmetic operations
1 parent ab0dd41 commit 829cfe7

File tree

7 files changed

+241
-70
lines changed

7 files changed

+241
-70
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 55 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
3737
with ExpectsInputTypes with NullIntolerant {
3838
private val checkOverflow = SQLConf.get.ansiEnabled
3939

40+
override def nullable: Boolean = true
41+
4042
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
4143

4244
override def dataType: DataType = child.dataType
@@ -75,12 +77,29 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
7577
"""})
7678
case _: CalendarIntervalType =>
7779
val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
78-
defineCodeGen(ctx, ev, c => s"$iu.negate($c)")
80+
nullSafeCodeGen(ctx, ev, interval => s"""
81+
try {
82+
${ev.value} = $iu.negate($interval);
83+
} catch (ArithmeticException e) {
84+
if ($checkOverflow) {
85+
throw new ArithmeticException("-($interval) caused interval overflow.");
86+
} else {
87+
${ev.isNull} = true;
88+
}
89+
}
90+
""")
7991
}
8092

8193
protected override def nullSafeEval(input: Any): Any = dataType match {
82-
case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval])
83-
case _ => numeric.negate(input)
94+
case CalendarIntervalType =>
95+
try {
96+
IntervalUtils.negate(input.asInstanceOf[CalendarInterval])
97+
} catch {
98+
case _: ArithmeticException if checkOverflow =>
99+
throw new ArithmeticException(s"$sql caused interval overflow")
100+
case _: ArithmeticException => null
101+
}
102+
case _ => numeric.negate(input)
84103
}
85104

86105
override def sql: String = s"(- ${child.sql})"
@@ -139,6 +158,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
139158

140159
override def dataType: DataType = left.dataType
141160

161+
override def nullable: Boolean = true
162+
142163
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
143164

144165
/** Name of the function for this expression on a [[Decimal]] type. */
@@ -160,7 +181,19 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
160181
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
161182
case CalendarIntervalType =>
162183
val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
163-
defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)")
184+
nullSafeCodeGen(ctx, ev, (eval1, eval2) =>
185+
s"""
186+
|try {
187+
| ${ev.value} = $iu.$calendarIntervalMethod($eval1, $eval2);
188+
|} catch (ArithmeticException e) {
189+
| if ($checkOverflow) {
190+
| throw new ArithmeticException(
191+
| "$eval1 $calendarIntervalMethod $eval2 caused interval overflow.");
192+
| } else {
193+
| ${ev.isNull} = true;
194+
| }
195+
|}
196+
|""".stripMargin)
164197
// byte and short are casted into int when add, minus, times or divide
165198
case ByteType | ShortType =>
166199
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
@@ -229,8 +262,15 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
229262
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)
230263

231264
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
232-
case CalendarIntervalType => IntervalUtils.add(
233-
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
265+
case CalendarIntervalType =>
266+
try {
267+
IntervalUtils.add(
268+
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
269+
} catch {
270+
case _: ArithmeticException if checkOverflow =>
271+
throw new ArithmeticException(s"$sql causes interval overflow")
272+
case _: ArithmeticException => null
273+
}
234274
case _ => numeric.plus(input1, input2)
235275
}
236276

@@ -257,8 +297,15 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
257297
private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)
258298

259299
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
260-
case CalendarIntervalType => IntervalUtils.subtract(
261-
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
300+
case CalendarIntervalType =>
301+
try {
302+
IntervalUtils.subtract(
303+
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
304+
} catch {
305+
case _: ArithmeticException if checkOverflow =>
306+
throw new ArithmeticException(s"$sql caused interval overflow")
307+
case _: ArithmeticException => null
308+
}
262309
case _ => numeric.minus(input1, input2)
263310
}
264311

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import scala.util.control.NonFatal
2424
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
2525
import org.apache.spark.sql.catalyst.util.IntervalUtils
2626
import org.apache.spark.sql.catalyst.util.IntervalUtils._
27+
import org.apache.spark.sql.internal.SQLConf
2728
import org.apache.spark.sql.types._
2829
import org.apache.spark.unsafe.types.CalendarInterval
2930

@@ -118,6 +119,8 @@ abstract class IntervalNumOperation(
118119
operation: (CalendarInterval, Double) => CalendarInterval,
119120
operationName: String)
120121
extends BinaryExpression with ImplicitCastInputTypes with Serializable {
122+
private val checkOverflow = SQLConf.get.ansiEnabled
123+
121124
override def left: Expression = interval
122125
override def right: Expression = num
123126

@@ -130,7 +133,9 @@ abstract class IntervalNumOperation(
130133
try {
131134
operation(interval.asInstanceOf[CalendarInterval], num.asInstanceOf[Double])
132135
} catch {
133-
case _: java.lang.ArithmeticException => null
136+
case _: ArithmeticException if checkOverflow =>
137+
throw new ArithmeticException(s"$sql caused interval overflow.")
138+
case _: ArithmeticException => null
134139
}
135140
}
136141

@@ -140,8 +145,12 @@ abstract class IntervalNumOperation(
140145
s"""
141146
try {
142147
${ev.value} = $iu.$operationName($interval, $num);
143-
} catch (java.lang.ArithmeticException e) {
144-
${ev.isNull} = true;
148+
} catch (ArithmeticException e) {
149+
if ($checkOverflow) {
150+
throw new ArithmeticException("$prettyName($interval, $num) caused interval overflow.");
151+
} else {
152+
${ev.isNull} = true;
153+
}
145154
}
146155
"""
147156
})

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -420,26 +420,29 @@ object IntervalUtils {
420420
* @return a new calendar interval instance with all its parameters negated from the original one.
421421
*/
422422
def negate(interval: CalendarInterval): CalendarInterval = {
423-
new CalendarInterval(-interval.months, -interval.days, -interval.microseconds)
423+
val months = Math.negateExact(interval.months)
424+
val days = Math.negateExact(interval.days)
425+
val microseconds = Math.negateExact(interval.microseconds)
426+
new CalendarInterval(months, days, microseconds)
424427
}
425428

426429
/**
427430
* Return a new calendar interval instance of the sum of two intervals.
428431
*/
429432
def add(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
430-
val months = left.months + right.months
431-
val days = left.days + right.days
432-
val microseconds = left.microseconds + right.microseconds
433+
val months = Math.addExact(left.months, right.months)
434+
val days = Math.addExact(left.days, right.days)
435+
val microseconds = Math.addExact(left.microseconds, right.microseconds)
433436
new CalendarInterval(months, days, microseconds)
434437
}
435438

436439
/**
437440
* Return a new calendar interval instance of the left interval minus the right one.
438441
*/
439442
def subtract(left: CalendarInterval, right: CalendarInterval): CalendarInterval = {
440-
val months = left.months - right.months
441-
val days = left.days - right.days
442-
val microseconds = left.microseconds - right.microseconds
443+
val months = Math.subtractExact(left.months, right.months)
444+
val days = Math.subtractExact(left.days, right.days)
445+
val microseconds = Math.subtractExact(left.microseconds, right.microseconds)
443446
new CalendarInterval(months, days, microseconds)
444447
}
445448

@@ -448,7 +451,7 @@ object IntervalUtils {
448451
}
449452

450453
def divide(interval: CalendarInterval, num: Double): CalendarInterval = {
451-
if (num == 0) throw new java.lang.ArithmeticException("divide by zero")
454+
if (num == 0) throw new ArithmeticException("divide by zero")
452455
fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
453456
}
454457

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,4 +445,24 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
445445
checkFail("5 30-12", DAY, SECOND, "must match day-time format")
446446
checkFail("5 1:12:20", HOUR, MICROSECOND, "Cannot support (interval")
447447
}
448+
449+
test("interval overflow check") {
450+
intercept[ArithmeticException](negate(new CalendarInterval(Int.MinValue, 0, 0)))
451+
intercept[ArithmeticException](negate(CalendarInterval.MIN_VALUE))
452+
453+
intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)))
454+
intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)))
455+
intercept[ArithmeticException](add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)))
456+
457+
intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE,
458+
new CalendarInterval(0, 0, -1)))
459+
intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE,
460+
new CalendarInterval(0, -1, 0)))
461+
intercept[ArithmeticException](subtract(CalendarInterval.MAX_VALUE,
462+
new CalendarInterval(-1, 0, 0)))
463+
464+
intercept[ArithmeticException](multiply(CalendarInterval.MAX_VALUE, 2))
465+
466+
intercept[ArithmeticException](divide(CalendarInterval.MAX_VALUE, 0.5))
467+
}
448468
}

sql/core/src/test/resources/sql-tests/inputs/interval.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,10 @@ select interval 'interval \t 1\tday';
269269
select interval 'interval\t1\tday';
270270
select interval '1\t' day;
271271
select interval '1 ' day;
272+
273+
-- interval overflow: throws an exception if ANSI mode is enabled, otherwise returns NULL
274+
select -(a) from values (interval '-2147483648 months', interval '2147483647 months') t(a, b);
275+
select a - b from values (interval '-2147483648 months', interval '2147483647 months') t(a, b);
276+
select b + interval '1 month' from values (interval '-2147483648 months', interval '2147483647 months') t(a, b);
277+
select a * 2 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b);
278+
select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b);

0 commit comments

Comments
 (0)