10 changes: 5 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -71,14 +71,14 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
f.run {
// This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which
// is a cached thread pool.
- val results = new ArrayBuffer[T](num)
+ val results = new ArrayBuffer[T]
val totalParts = self.partitions.length
- var partsScanned = 0
+ var partsScanned = 0L
self.context.setCallSite(callSite)
while (results.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -94,7 +94,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
}

val left = num - results.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt

val buf = new Array[Array[T]](p.size)
f.runJob(self,
@@ -104,7 +104,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging
Unit)

buf.foreach(results ++= _.take(num - results.size))
- partsScanned += numPartsToTry
+ partsScanned += p.size
}
results.toSeq
}(AsyncRDDActions.futureExecutionContext)
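The same widening from Int to Long recurs in RDD.take and SparkPlan.executeTake below, so the rationale is sketched once here. What follows is a minimal standalone illustration, not Spark source; the interpolation expression (1.5 * limit * partsScanned / rowsSoFar) is an assumption, since that line sits outside the visible hunks. With Int counters, a take/limit close to Int.MaxValue saturates the estimate at Int.MaxValue, the addition partsScanned + numPartsToTry wraps negative, the partition range comes out empty, and the loop never makes progress. With Long counters, capped against totalParts before a single .toInt, the range stays correct.

// Sketch only (not Spark source). Values mirror the regression test added below:
// a limit of 2147483638 over 3 partitions, one row collected after the first scan.
object TakeOverflowSketch {
  def main(args: Array[String]): Unit = {
    val totalParts = 3
    val limit = 2147483638

    // Int counters: the estimate saturates at Int.MaxValue and the addition wraps.
    val partsScannedI = 1
    val numPartsToTryI = (1.5 * limit * partsScannedI / 1).toInt           // 2147483647
    val upperI = math.min(partsScannedI + numPartsToTryI, totalParts)      // -2147483648 (wrapped)
    println((partsScannedI until upperI).size)                             // 0: empty job, no progress

    // Long counters, capped before the single .toInt: the range covers what is left.
    val partsScannedL = 1L
    val numPartsToTryL = (1.5 * limit * partsScannedL / 1).toLong
    val upperL = math.min(partsScannedL + numPartsToTryL, totalParts).toInt   // 3
    println((partsScannedL.toInt until upperL).size)                          // 2: remaining partitions scanned
  }
}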
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1291,11 +1291,11 @@ abstract class RDD[T: ClassTag](
} else {
val buf = new ArrayBuffer[T]
val totalParts = this.partitions.length
- var partsScanned = 0
+ var partsScanned = 0L
while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
@@ -1310,11 +1310,11 @@ abstract class RDD[T: ClassTag](
}

val left = num - buf.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

res.foreach(buf ++= _.take(num - buf.size))
- partsScanned += numPartsToTry
+ partsScanned += p.size
}

buf.toArray
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -188,11 +188,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable

val buf = new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
- var partsScanned = 0
+ var partsScanned = 0L
while (buf.size < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the first iteration, just try all partitions next.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -206,13 +206,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

val left = n - buf.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
val sc = sqlContext.sparkContext
val res =
sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)

res.foreach(buf ++= _.take(n - buf.size))
- partsScanned += numPartsToTry
+ partsScanned += p.size
}

buf.toArray
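A second detail, shared by all three hunks above: the loop now advances partsScanned by p.size rather than by the raw numPartsToTry estimate. Because the range p is already capped at totalParts, its size is the number of partitions actually submitted in that round, so the counter can no longer drift past reality when the estimate overshoots. A short sketch with the same three-partition setup:

val totalParts = 3
var partsScanned = 1L
val numPartsToTry = 1000L                 // deliberately over-estimated
val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
partsScanned += p.size                    // advances by 2 (partitions 1 and 2), not by 1000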
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2028,4 +2028,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(false) :: Row(true) :: Nil)
}

test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 )
rdd.toDF("key").registerTempTable("spark12340")
checkAnswer(
sql("select key from spark12340 limit 2147483638"),
Row(1) :: Row(2) :: Row(3) :: Nil
)
assert(rdd.take(2147483638).size === 3)
assert(rdd.takeAsync(2147483638).get.size === 3)
}

}