From 151a5efe37376161dc4513e3087661cc035bb684 Mon Sep 17 00:00:00 2001
From: QiangCai
Date: Sun, 3 Jan 2016 21:47:58 +0800
Subject: [PATCH 1/4] [SPARK-12340][SQL]fix Int overflow in the SparkPlan.executeTake, RDD.take and AsyncRDDActions.takeAsync

---
 .../org/apache/spark/rdd/AsyncRDDActions.scala      | 10 +++++-----
 core/src/main/scala/org/apache/spark/rdd/RDD.scala  |  8 ++++----
 .../org/apache/spark/sql/execution/SparkPlan.scala  |  8 ++++----
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala  | 13 +++++++++++++
 4 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 14f541f937b4c..6462ed4f8106e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -77,13 +77,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
       This implementation is non-blocking, asynchronously handling the results of each job and
       triggering the next job using callbacks on futures.
      */
-    def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
+    def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
       if (results.size >= num || partsScanned >= totalParts) {
         Future.successful(results.toSeq)
       } else {
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
-        var numPartsToTry = 1
+        var numPartsToTry = 1L
         if (partsScanned > 0) {
           // If we didn't find any rows after the previous iteration, quadruple and retry.
           // Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
         }

         val left = num - results.size
-        val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+        val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
         val buf = new Array[Array[T]](p.size)

         self.context.setCallSite(callSite)
@@ -111,11 +111,11 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
           Unit)
         job.flatMap {_ =>
           buf.foreach(results ++= _.take(num - results.size))
-          continue(partsScanned + numPartsToTry)
+          continue(partsScanned + p.size)
         }
       }

-    new ComplexFutureAction[Seq[T]](continue(0)(_))
+    new ComplexFutureAction[Seq[T]](continue(0L)(_))
   }

   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 9fe9d83a705b2..7b09662c88e48 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1291,11 +1291,11 @@ abstract class RDD[T: ClassTag](
     } else {
       val buf = new ArrayBuffer[T]
       val totalParts = this.partitions.length
-      var partsScanned = 0
+      var partsScanned = 0L
       while (buf.size < num && partsScanned < totalParts) {
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
-        var numPartsToTry = 1
+        var numPartsToTry = 1L
         if (partsScanned > 0) {
           // If we didn't find any rows after the previous iteration, quadruple and retry.
           // Otherwise, interpolate the number of partitions we need to try, but overestimate
@@ -1310,11 +1310,11 @@ abstract class RDD[T: ClassTag](
         }

         val left = num - buf.size
-        val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+        val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
         val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

         res.foreach(buf ++= _.take(num - buf.size))
-        partsScanned += numPartsToTry
+        partsScanned += p.size
       }

       buf.toArray
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index f20f32aaced2e..21a6fba9078df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -165,11 +165,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ

     val buf = new ArrayBuffer[InternalRow]
     val totalParts = childRDD.partitions.length
-    var partsScanned = 0
+    var partsScanned = 0L
     while (buf.size < n && partsScanned < totalParts) {
       // The number of partitions to try in this iteration. It is ok for this number to be
       // greater than totalParts because we actually cap it at totalParts in runJob.
-      var numPartsToTry = 1
+      var numPartsToTry = 1L
       if (partsScanned > 0) {
         // If we didn't find any rows after the first iteration, just try all partitions next.
         // Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -183,13 +183,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions

       val left = n - buf.size
-      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+      val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
       val sc = sqlContext.sparkContext
       val res = sc.runJob(childRDD,
         (it: Iterator[InternalRow]) => it.take(left).toArray, p)

       res.foreach(buf ++= _.take(n - buf.size))
-      partsScanned += numPartsToTry
+      partsScanned += p.size
     }

     buf.toArray
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 72845711adddd..5ac1fe16b154b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2028,6 +2028,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(false) :: Row(true) :: Nil)
   }

+<<<<<<< 1cdc42d2b99edfec01066699a7620cca02b61f0e
   test("rollup") {
     checkAnswer(
       sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" +
@@ -2067,4 +2068,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       )
     }
   }
+
+  test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake"){
+    val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 )
+    rdd.toDF("key").registerTempTable("spark12340")
+    checkAnswer(
+      sql("select key from spark12340 limit 2147483638"),
+      Row(1) :: Row(2) :: Row(3) :: Nil
+    )
+    assert(rdd.take(2147483638).size === 3)
+    assert(rdd.takeAsync(2147483638).get.size === 3)
+  }
+
 }

From 3d340f730309f9a2930051caea0e516ef52b1d06 Mon Sep 17 00:00:00 2001
From: QiangCai
Date: Tue, 5 Jan 2016 21:34:39 +0800
Subject: [PATCH 2/4] merge conflict

---
 sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5ac1fe16b154b..1b6afea29c1fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2028,7 +2028,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(false) :: Row(true) :: Nil)
   }

-<<<<<<< 1cdc42d2b99edfec01066699a7620cca02b61f0e
   test("rollup") {
     checkAnswer(
       sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" +

From ca18565f599a5f075d65425ec9fdcfa160c5df00 Mon Sep 17 00:00:00 2001
From: QiangCai
Date: Wed, 6 Jan 2016 00:59:09 +0800
Subject: [PATCH 3/4] modify param to be small enough

---
 .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 1b6afea29c1fb..e869cb8dbc86d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2076,7 +2076,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(1) :: Row(2) :: Row(3) :: Nil
     )
     assert(rdd.take(2147483638).size === 3)
-    assert(rdd.takeAsync(2147483638).get.size === 3)
+    assert(rdd.takeAsync(3).get.size === 3)
   }

 }

From e7577ee98630d1e53782e6f7dbc7979c1bc558a8 Mon Sep 17 00:00:00 2001
From: QiangCai
Date: Wed, 6 Jan 2016 01:13:05 +0800
Subject: [PATCH 4/4] fix initial of result in AsyncRDDActions.takeAsync

---
 .../src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala | 2 +-
 .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala   | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 6462ed4f8106e..c91e3bea406a5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -68,7 +68,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
     val localProperties = self.context.getLocalProperties
     // Cached thread pool to handle aggregation of subtasks.
     implicit val executionContext = AsyncRDDActions.futureExecutionContext
-    val results = new ArrayBuffer[T](num)
+    val results = new ArrayBuffer[T]
     val totalParts = self.partitions.length

     /*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index e869cb8dbc86d..a0d33ce4603d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2068,7 +2068,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }

-  test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake"){
+  test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
    val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 )
    rdd.toDF("key").registerTempTable("spark12340")
    checkAnswer(
@@ -2076,7 +2076,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
      Row(1) :: Row(2) :: Row(3) :: Nil
    )
    assert(rdd.take(2147483638).size === 3)
-    assert(rdd.takeAsync(3).get.size === 3)
+    assert(rdd.takeAsync(2147483638).get.size === 3)
   }

 }
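
For readers tracing why this series widens partsScanned and numPartsToTry to Long: the short
standalone Scala sketch below (not part of the patch) reproduces the kind of Int overflow being
fixed. The values mirror the new regression test (a limit of 2147483638 over a 3-partition RDD);
the object name, value names, and the exact overestimation expression are illustrative
assumptions, not code lifted from Spark.

    // Standalone sketch of the SPARK-12340 overflow; names and the estimate are illustrative.
    object Spark12340OverflowSketch {
      def main(args: Array[String]): Unit = {
        val n = 2147483638          // requested rows, just below Int.MaxValue, as in the new test
        val totalParts = 3          // partitions in the scanned RDD
        val bufSize = 1             // rows collected after scanning the first partition
        val partsScannedInt = 1     // Int scan cursor, as in the pre-patch loops

        // A "1.5x overestimate" of the partitions still needed saturates at Int.MaxValue.
        val numPartsToTry = (1.5 * n * partsScannedInt / bufSize).toInt
        println(numPartsToTry)                      // 2147483647

        // With Int arithmetic the upper bound of the next partition range wraps negative,
        // so the range is empty and the take loop stops making progress.
        println(partsScannedInt + numPartsToTry)    // -2147483648

        // With a Long cursor (the patched code path) the sum stays positive and the
        // cap at totalParts behaves as intended.
        val partsScannedLong = 1L
        println(math.min(partsScannedLong + numPartsToTry, totalParts))  // 3
      }
    }

Running the sketch prints 2147483647, -2147483648 and 3: an Int-based cursor can wrap negative
and stall the scan, while a Long-based cursor keeps the math.min(..., totalParts) cap meaningful,
which is the behaviour the 2147483638-row limit in the new SQLQuerySuite test exercises.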