sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -22,6 +22,7 @@ import java.util.Comparator
 import scala.collection.mutable
 import scala.reflect.ClassTag
 
+import org.apache.spark.SparkException.internalError
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -40,7 +41,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
 
 /**
@@ -3080,6 +3080,34 @@ case class Sequence(
 }
 
 object Sequence {
+  private def prettyName: String = "sequence"
+
+  def sequenceLength(start: Long, stop: Long, step: Long): Int = {
+    try {
+      val delta = Math.subtractExact(stop, start)
+      if (delta == Long.MinValue && step == -1L) {
+        // We must special-case Long.MinValue / -1: division is the only operation here
+        // without a built-in overflow check, and this is the only operand pair for which
+        // it overflows. A zero step has already been rejected by the boundary checks.
+        throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
+      }
+      val len = if (stop == start) 1L else Math.addExact(1L, delta / step)
+      if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(len)
+      }
+      len.toInt
+    } catch {
+      // An ArithmeticException from the block above means the checked arithmetic
+      // overflowed. Recompute the length with BigInt, which cannot overflow, so the
+      // error message reports the exact element count.
+      case _: ArithmeticException =>
+        val safeLen = BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
+        if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+          throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(safeLen)
+        }
+        throw internalError("Unreachable code reached.")
+      case e: Exception => throw e
+    }
+  }
+
   private type LessThanOrEqualFn = (Any, Any) => Boolean
 
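Why the explicit `Long.MinValue / -1` guard matters: Long division, unlike `Math.addExact` and `Math.subtractExact`, wraps silently instead of throwing. A standalone sketch (object name hypothetical, not part of the patch) of the one input pattern that reaches this case, e.g. `sequence(Long.MaxValue, -1, -1)`:

```scala
object DivisionOverflowSketch extends App {
  // Long division has no overflow signal: Long.MinValue / -1 silently wraps
  // back to Long.MinValue instead of the mathematically correct 2^63.
  println(Long.MinValue / -1L) // -9223372036854775808

  // sequenceLength reaches delta == Long.MinValue without an exception, because
  // subtractExact(-1, Long.MaxValue) == Long.MinValue is exactly representable.
  println(Math.subtractExact(-1L, Long.MaxValue)) // -9223372036854775808

  // The BigInt fallback then reports the true length of 2^63 + 1 elements.
  println(BigInt(1) + (BigInt(-1L) - BigInt(Long.MaxValue)) / BigInt(-1L)) // 9223372036854775809
}
```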
@@ -3451,13 +3479,7 @@
         || (estimatedStep == num.zero && start == stop),
       s"Illegal sequence boundaries: $start to $stop by $step")
 
-    val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
-
-    require(
-      len <= MAX_ROUNDED_ARRAY_LENGTH,
-      s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
-
-    len.toInt
+    sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong)
   }
 
   private def genSequenceLengthCode(
@@ -3467,20 +3489,15 @@
       step: String,
       estimatedStep: String,
       len: String): String = {
-    val longLen = ctx.freshName("longLen")
+    val calcFn = classOf[Sequence].getName + ".sequenceLength"
     s"""
        |if (!(($estimatedStep > 0 && $start <= $stop) ||
        |  ($estimatedStep < 0 && $start >= $stop) ||
        |  ($estimatedStep == 0 && $start == $stop))) {
        |  throw new IllegalArgumentException(
        |    "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step);
        |}
-       |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
-       |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
-       |  throw new IllegalArgumentException(
-       |    "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH");
-       |}
-       |int $len = (int) $longLen;
+       |int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep);
     """.stripMargin
   }
 }
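For context on why the inline codegen was replaced: the old generated expression `1L + ((long) stop - start) / step` wraps for wide Long ranges, silently yielding a wrong length instead of an error. A minimal sketch (standalone, object name hypothetical) of the wrap that the shared `sequenceLength` now catches:

```scala
object WrapAroundSketch extends App {
  val start = Long.MinValue
  val stop = Long.MaxValue

  // Unchecked subtraction wraps: the true difference 2^64 - 1 is not
  // representable in a Long, so stop - start evaluates to -1 ...
  println(stop - start) // -1

  // ... and the old inline length formula would conclude the sequence is empty.
  println(1L + (stop - start) / 1L) // 0

  // Math.subtractExact throws instead, routing sequenceLength into its
  // BigInt fallback and an accurate "too long sequence" error.
  try Math.subtractExact(stop, start)
  catch { case e: ArithmeticException => println(e.getMessage) } // long overflow
}
```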
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingZoneIds,
 import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String
 
 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -769,10 +769,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper

     // test sequence boundaries checking
 
-    checkExceptionInExpression[IllegalArgumentException](
-      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
-      EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
-
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0")
     checkExceptionInExpression[IllegalArgumentException](
@@ -782,6 +778,44 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1")
 
+    // SPARK-43393: test Sequence overflow checking
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Int.MaxValue) - BigInt(Int.MinValue) + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> ((0 - BigInt(Long.MinValue)) + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) - BigInt(Long.MinValue) + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) - BigInt(Long.MinValue) + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)),
+      errorClass = "_LEGACY_ERROR_TEMP_2161",
+      parameters = Map(
+        "count" -> (BigInt(Long.MaxValue) - BigInt(-1L) + 1).toString,
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString))
+
     // test sequence with one element (zero step or equal start and stop)
 
     checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1))
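As a quick sanity check on the `count` values the new tests expect, the same BigInt arithmetic can be evaluated standalone (a sketch, not part of the patch); each count exceeds `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`, which is an Int-sized limit:

```scala
object ExpectedCounts extends App {
  // Int.MinValue to Int.MaxValue by 1: 2^32 elements.
  println(BigInt(Int.MaxValue) - BigInt(Int.MinValue) + 1) // 4294967296
  // 0 to Long.MaxValue by 1: 2^63 elements.
  println(BigInt(Long.MaxValue) + 1) // 9223372036854775808
  // 0 down to Long.MinValue by -1: 2^63 + 1 elements.
  println((0 - BigInt(Long.MinValue)) + 1) // 9223372036854775809
  // Long.MinValue to Long.MaxValue by 1: 2^64 elements.
  println(BigInt(Long.MaxValue) - BigInt(Long.MinValue) + 1) // 18446744073709551616
}
```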