From b35e9b922e893ab825bfd43370cc91852f6cc50a Mon Sep 17 00:00:00 2001
From: QiangCai
Date: Wed, 6 Jan 2016 23:51:17 +0800
Subject: [PATCH] avoid Int overflow

---
 .../scala/org/apache/spark/rdd/AsyncRDDActions.scala | 10 +++++-----
 core/src/main/scala/org/apache/spark/rdd/RDD.scala   |  8 ++++----
 .../org/apache/spark/sql/execution/SparkPlan.scala   |  8 ++++----
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala   | 11 +++++++++++
 4 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index d5e853613b05b..cdf604b519746 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -71,14 +71,14 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
     f.run {
       // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which
       // is a cached thread pool.
-      val results = new ArrayBuffer[T](num)
+      val results = new ArrayBuffer[T]
       val totalParts = self.partitions.length
-      var partsScanned = 0
+      var partsScanned = 0L
       self.context.setCallSite(callSite)
       while (results.size < num && partsScanned < totalParts) {
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
-        var numPartsToTry = 1
+        var numPartsToTry = 1L
         if (partsScanned > 0) {
           // If we didn't find any rows after the previous iteration, quadruple and retry.
           // Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -94,7 +94,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
         }
 
         val left = num - results.size
-        val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+        val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
 
         val buf = new Array[Array[T]](p.size)
         f.runJob(self,
@@ -104,7 +104,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
           Unit)
 
         buf.foreach(results ++= _.take(num - results.size))
-        partsScanned += numPartsToTry
+        partsScanned += p.size
       }
       results.toSeq
     }(AsyncRDDActions.futureExecutionContext)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 9fe9d83a705b2..7b09662c88e48 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1291,11 +1291,11 @@ abstract class RDD[T: ClassTag](
     } else {
       val buf = new ArrayBuffer[T]
       val totalParts = this.partitions.length
-      var partsScanned = 0
+      var partsScanned = 0L
       while (buf.size < num && partsScanned < totalParts) {
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
-        var numPartsToTry = 1
+        var numPartsToTry = 1L
         if (partsScanned > 0) {
           // If we didn't find any rows after the previous iteration, quadruple and retry.
           // Otherwise, interpolate the number of partitions we need to try, but overestimate
@@ -1310,11 +1310,11 @@ abstract class RDD[T: ClassTag](
         }
 
         val left = num - buf.size
-        val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+        val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
         val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)
 
         res.foreach(buf ++= _.take(num - buf.size))
-        partsScanned += numPartsToTry
+        partsScanned += p.size
       }
 
       buf.toArray
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ec98f81041343..20f39c4c7fd0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -188,11 +188,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
 
     val buf = new ArrayBuffer[InternalRow]
     val totalParts = childRDD.partitions.length
-    var partsScanned = 0
+    var partsScanned = 0L
     while (buf.size < n && partsScanned < totalParts) {
       // The number of partitions to try in this iteration. It is ok for this number to be
       // greater than totalParts because we actually cap it at totalParts in runJob.
-      var numPartsToTry = 1
+      var numPartsToTry = 1L
       if (partsScanned > 0) {
         // If we didn't find any rows after the first iteration, just try all partitions next.
         // Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -206,13 +206,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
 
       val left = n - buf.size
-      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+      val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
 
       val sc = sqlContext.sparkContext
       val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
 
       res.foreach(buf ++= _.take(n - buf.size))
-      partsScanned += numPartsToTry
+      partsScanned += p.size
     }
 
     buf.toArray
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bb82b562aaaa2..582d6ce7531b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2028,4 +2028,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(false) :: Row(true) :: Nil)
   }
 
+  test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
+    val rdd = sqlContext.sparkContext.parallelize(1 to 3, 3)
+    rdd.toDF("key").registerTempTable("spark12340")
+    checkAnswer(
+      sql("select key from spark12340 limit 2147483638"),
+      Row(1) :: Row(2) :: Row(3) :: Nil
+    )
+    assert(rdd.take(2147483638).size === 3)
+    assert(rdd.takeAsync(2147483638).get.size === 3)
+  }
+
 }
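
Note on the overflow (not part of the patch): the hunks above only show the counters changing from Int to Long; the interpolation step that actually overflows lies just outside the diff context. The standalone sketch below is a minimal illustration for SparkPlan.executeTake, assuming the pre-patch code computed numPartsToTry roughly as (1.5 * n * partsScanned / buf.size).toInt, consistent with the "overestimate it by 50%" comment in the context lines. The object name Spark12340Sketch and the concrete values are hypothetical; the limit 2147483638 is taken from the new test.

// A minimal, standalone sketch of the Int overflow that the Long counters avoid.
// The names n, totalParts, partsScanned and bufSize mirror the locals in
// SparkPlan.executeTake; the exact interpolation expression is an assumption
// reconstructed from the surrounding diff context.
object Spark12340Sketch {
  def main(args: Array[String]): Unit = {
    val n = 2147483638        // requested row count, close to Int.MaxValue
    val totalParts = 3        // partitions in the child RDD
    val partsScanned = 1      // partitions scanned so far
    val bufSize = 1           // rows collected so far

    // Old Int arithmetic: the interpolation saturates at Int.MaxValue ...
    val numPartsToTryInt = (1.5 * n * partsScanned / bufSize).toInt   // 2147483647
    // ... so the upper bound of the partition range wraps to a negative Int, making
    // "partsScanned until upperInt" empty: no new rows are read, and partsScanned
    // itself overflows on the next "partsScanned += numPartsToTry".
    val upperInt = math.min(partsScanned + numPartsToTryInt, totalParts)
    println(upperInt)                                                 // -2147483648

    // Patched Long arithmetic: nothing wraps, and the range is capped at totalParts.
    val numPartsToTryLong = (1.5 * n * partsScanned / bufSize).toLong
    val upperLong = math.min(partsScanned + numPartsToTryLong, totalParts.toLong).toInt
    println(upperLong)                                                // 3
  }
}

This is also why the patch increments partsScanned by p.size rather than numPartsToTry: the counter then reflects only the partitions actually submitted, so it can never run ahead of totalParts.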