Commit f5900a5

thepinetree authored and dongjoon-hyun committed
[SPARK-43393][SQL][3.4] Address sequence expression overflow bug
### What changes were proposed in this pull request?

Spark has a (long-standing) overflow bug in the `sequence` expression.

Consider the following operations:
```
spark.sql("CREATE TABLE foo (l LONG);")
spark.sql(s"INSERT INTO foo VALUES (${Long.MaxValue});")
spark.sql("SELECT sequence(0, l) FROM foo;").collect()
```

The result of these operations will be:
```
Array[org.apache.spark.sql.Row] = Array([WrappedArray()])
```
an unintended consequence of overflow.

The sequence is applied to values `0` and `Long.MaxValue` with a step size of `1` which uses a length computation defined [here](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3451). In this calculation, with `start = 0`, `stop = Long.MaxValue`, and `step = 1`, the calculated `len` overflows to `Long.MinValue`. The computation, in binary looks like:

```
  0111111111111111111111111111111111111111111111111111111111111111
- 0000000000000000000000000000000000000000000000000000000000000000
------------------------------------------------------------------
  0111111111111111111111111111111111111111111111111111111111111111
/ 0000000000000000000000000000000000000000000000000000000000000001
------------------------------------------------------------------
  0111111111111111111111111111111111111111111111111111111111111111
+ 0000000000000000000000000000000000000000000000000000000000000001
------------------------------------------------------------------
  1000000000000000000000000000000000000000000000000000000000000000
```

The following [check](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3454) passes as the negative `Long.MinValue` is still `<= MAX_ROUNDED_ARRAY_LENGTH`. The following cast to `toInt` uses this representation and [truncates the upper bits](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3457) resulting in an empty length of `0`.

Other overflows are similarly problematic.

This PR addresses the issue by checking numeric operations in the length computation for overflow.

### Why are the changes needed?

There is a correctness bug from overflow in the `sequence` expression.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Tests added in `CollectionExpressionsSuite.scala`.

Closes #43819 from thepinetree/spark-sequence-overflow-3.4.

Authored-by: Deepayan Patra <deepayan.patra@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
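
The failure described in the commit message can be reproduced without Spark. The following is a minimal sketch, not the Spark code itself: `NaiveSequenceLength`, `naiveLen`, and `MaxRoundedArrayLength` are hypothetical names, and the constant is assumed to equal `Int.MaxValue - 15` as a stand-in for `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`.

```scala
object NaiveSequenceLength {
  // Assumed stand-in for ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.
  val MaxRoundedArrayLength: Int = Int.MaxValue - 15

  // The unchecked length formula, as used before this change.
  def naiveLen(start: Long, stop: Long, step: Long): Long =
    if (start == stop) 1L else 1L + (stop - start) / step

  def main(args: Array[String]): Unit = {
    val len = naiveLen(0L, Long.MaxValue, 1L)
    println(len == Long.MinValue)         // true: 1 + Long.MaxValue wrapped around
    println(len <= MaxRoundedArrayLength) // true: the bound check does not catch it
    println(len.toInt)                    // 0: only the low 32 bits survive
  }
}
```

Because the wrapped value is negative, an upper-bound check alone cannot reject it, which is why the fix checks each arithmetic step instead.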
1 parent 23f15af commit f5900a5

2 files changed: +71 -20 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 32 additions & 15 deletions
@@ -22,6 +22,7 @@ import java.util.Comparator
 import scala.collection.mutable
 import scala.reflect.ClassTag
 
+import org.apache.spark.SparkException.internalError
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -39,7 +40,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
 
 /**
@@ -3011,6 +3011,34 @@ case class Sequence(
 }
 
 object Sequence {
+  private def prettyName: String = "sequence"
+
+  def sequenceLength(start: Long, stop: Long, step: Long): Int = {
+    try {
+      val delta = Math.subtractExact(stop, start)
+      if (delta == Long.MinValue && step == -1L) {
+        // We must special-case division of Long.MinValue by -1 to catch potential unchecked
+        // overflow in next operation. Division does not have a builtin overflow check. We
+        // previously special-case div-by-zero.
+        throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
+      }
+      val len = if (stop == start) 1L else Math.addExact(1L, (delta / step))
+      if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(len)
+      }
+      len.toInt
+    } catch {
+      // We handle overflows in the previous try block by raising an appropriate exception.
+      case _: ArithmeticException =>
+        val safeLen =
+          BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
+        if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+          throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(safeLen)
+        }
+        throw internalError("Unreachable code reached.")
+      case e: Exception => throw e
+    }
+  }
 
   private type LessThanOrEqualFn = (Any, Any) => Boolean
 
@@ -3382,13 +3410,7 @@ object Sequence {
         || (estimatedStep == num.zero && start == stop),
       s"Illegal sequence boundaries: $start to $stop by $step")
 
-    val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
-
-    require(
-      len <= MAX_ROUNDED_ARRAY_LENGTH,
-      s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
-
-    len.toInt
+    sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong)
   }
 
   private def genSequenceLengthCode(
@@ -3398,20 +3420,15 @@ object Sequence {
       step: String,
       estimatedStep: String,
       len: String): String = {
-    val longLen = ctx.freshName("longLen")
+    val calcFn = classOf[Sequence].getName + ".sequenceLength"
     s"""
        |if (!(($estimatedStep > 0 && $start <= $stop) ||
        |  ($estimatedStep < 0 && $start >= $stop) ||
        |  ($estimatedStep == 0 && $start == $stop))) {
        |  throw new IllegalArgumentException(
        |    "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step);
        |}
-       |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
-       |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
-       |  throw new IllegalArgumentException(
-       |    "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH");
-       |}
-       |int $len = (int) $longLen;
+       |int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep);
        """.stripMargin
   }
 }
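
The checked arithmetic in `sequenceLength` relies on `Math.subtractExact` and `Math.addExact` throwing `ArithmeticException` on overflow, plus one manual check, because `Long` division never throws on overflow. The sketch below illustrates those three failure points in isolation. `CheckedLengthSketch` and `checkedLen` are hypothetical names, and the `Either` return type is an illustrative choice rather than what the patch does (the patch throws a Spark error instead).

```scala
object CheckedLengthSketch {
  // Hypothetical helper (not part of Spark): returns the element count, or a
  // description of the overflow that was detected.
  // Assumes step != 0 whenever start != stop, as the boundary check guarantees.
  def checkedLen(start: Long, stop: Long, step: Long): Either[String, Long] =
    try {
      val delta = Math.subtractExact(stop, start) // throws on overflow
      if (delta == Long.MinValue && step == -1L) {
        // Plain division would silently wrap: Long.MinValue / -1L == Long.MinValue.
        Left("Long.MinValue / -1 overflows")
      } else if (stop == start) {
        Right(1L)
      } else {
        Right(Math.addExact(1L, delta / step)) // throws on overflow
      }
    } catch {
      case e: ArithmeticException => Left(e.getMessage)
    }

  def main(args: Array[String]): Unit = {
    println(checkedLen(1L, 10L, 3L))                      // elements 1, 4, 7, 10
    println(checkedLen(0L, Long.MaxValue, 1L))            // addExact overflows
    println(checkedLen(Long.MinValue, Long.MaxValue, 1L)) // subtractExact overflows
    println(checkedLen(0L, Long.MinValue, -1L))           // division special case
  }
}
```

Without the special case, the last call would compute `1 + Long.MinValue`, a negative value that passes the upper-bound check.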

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 39 additions & 5 deletions
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingZoneIds,
 import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String
 
 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -769,10 +769,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
     // test sequence boundaries checking
 
-    checkExceptionInExpression[IllegalArgumentException](
-      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
-      EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
-
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0")
     checkExceptionInExpression[IllegalArgumentException](
@@ -782,6 +778,44 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1")
 
+    // SPARK-43393: test Sequence overflow checking
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> ((0 - BigInt(Long.MinValue)) + 1).toString(),
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+
     // test sequence with one element (zero step or equal start and stop)
 
     checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1))
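
The `count` parameters expected by the new tests are the true sequence lengths, which do not fit in a `Long` for most of these inputs. As a rough cross-check, the sketch below recomputes a few of them with arbitrary-precision arithmetic. `ExpectedCounts` and `trueLength` are hypothetical names used only for illustration; the formula is the same one used in the `BigInt` fallback of `sequenceLength`.

```scala
object ExpectedCounts {
  // Same formula as the BigInt fallback in sequenceLength, without any Spark types.
  def trueLength(start: Long, stop: Long, step: Long): BigInt =
    BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)

  def main(args: Array[String]): Unit = {
    println(trueLength(Int.MinValue, Int.MaxValue, 1L))   // 2^32
    println(trueLength(0L, Long.MaxValue, 1L))            // 2^63
    println(trueLength(Long.MinValue, Long.MaxValue, 1L)) // 2^64
    println(trueLength(Long.MaxValue, -1L, -1L))          // 2^63 + 1
  }
}
```

The first value, 4294967296, matches the count quoted in the removed "Too long sequence" assertion, so the expected length for that case is unchanged even though the error type is different.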
