12 changes: 6 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -68,7 +68,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
val localProperties = self.context.getLocalProperties
// Cached thread pool to handle aggregation of subtasks.
implicit val executionContext = AsyncRDDActions.futureExecutionContext
- val results = new ArrayBuffer[T](num)
+ val results = new ArrayBuffer[T]
val totalParts = self.partitions.length

/*
@@ -77,13 +77,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
This implementation is non-blocking, asynchronously handling the
results of each job and triggering the next job using callbacks on futures.
*/
- def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
+ def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
if (results.size >= num || partsScanned >= totalParts) {
Future.successful(results.toSeq)
} else {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
}

val left = num - results.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt

val buf = new Array[Array[T]](p.size)
self.context.setCallSite(callSite)
@@ -111,11 +111,11 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
Unit)
job.flatMap {_ =>
buf.foreach(results ++= _.take(num - results.size))
- continue(partsScanned + numPartsToTry)
+ continue(partsScanned + p.size)
}
}

- new ComplexFutureAction[Seq[T]](continue(0)(_))
+ new ComplexFutureAction[Seq[T]](continue(0L)(_))
}

/**
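The takeAsync change above keeps the non-blocking structure the inline comment describes: each round submits a job over a slice of partitions and, when its future completes, either returns the accumulated rows or schedules the next slice. Below is a minimal sketch of that pattern with plain Scala Futures (no Spark; fetchPartitions is a hypothetical stand-in for the jobSubmitter.submitJob call, and the growth rule is simplified to plain quadrupling rather than the interpolation the real code uses).

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

object TakeAsyncSketch {
  // Hypothetical stand-in for submitting a job over a range of partitions;
  // here each "partition" just yields its own index as a single row.
  def fetchPartitions(parts: Range, rowsWanted: Int): Future[Seq[Int]] =
    Future(parts.take(rowsWanted).toSeq)

  def takeAsync(num: Int, totalParts: Int): Future[Seq[Int]] = {
    val results = new ArrayBuffer[Int]

    def continue(partsScanned: Long): Future[Seq[Int]] =
      if (results.size >= num || partsScanned >= totalParts) {
        Future.successful(results.toSeq)
      } else {
        // Simplified growth rule; Long arithmetic keeps the sum from wrapping.
        val numPartsToTry = if (partsScanned > 0) partsScanned * 4 else 1L
        val upper = math.min(partsScanned + numPartsToTry, totalParts.toLong).toInt
        val p = partsScanned.toInt until upper
        fetchPartitions(p, num - results.size).flatMap { rows =>
          results ++= rows.take(num - results.size)
          continue(partsScanned + p.size) // advance by what was actually scanned
        }
      }

    continue(0L)
  }
}

Awaiting TakeAsyncSketch.takeAsync(2, 5) completes with the first two fake rows after at most two rounds.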
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1291,11 +1291,11 @@ abstract class RDD[T: ClassTag](
} else {
val buf = new ArrayBuffer[T]
val totalParts = this.partitions.length
- var partsScanned = 0
+ var partsScanned = 0L
Contributor: Why is this change necessary? When can partsScanned go above 2B?

Member: Ah, you're right. partsScanned cannot exceed the value of totalParts. I'll revert it to Int.

Member: I think there is a legitimate problem here. Imagine totalParts is close to Int.MaxValue, and imagine partsScanned is close to totalParts. Adding p.size to it below could cause it to roll over. I think this change is needed.

Contributor: That's never possible -- if we have anywhere near 2B partitions, the scheduler won't be fast enough to schedule them. As a matter of fact, with anything larger than a few million partitions, the scheduler will likely crash.

Member: Fair point; in practice this all but certainly won't happen. Note that this patch was already committed to master making this a Long. It doesn't hurt and is, very theoretically, more correct locally. I suppose I don't think it's worth updating again, but I do not feel strongly about it.

Contributor: I'd prefer to change it back since it is so little work, so this does not start a trend of changing all Ints to Longs for no reason. Note that it also raises the question of why this value can be greater than Int.MaxValue when somebody reads this code in the future.

Also @srowen, even if totalParts is close to Int.MaxValue, I don't think partsScanned can be greater than Int.MaxValue, because we never scan more parts than the number of parts available.

Contributor: Ah, ok -- you were referring to partsScanned + numPartsToTry. We should just cast that to Long to minimize the impact.

while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
@@ -1310,11 +1310,11 @@ abstract class RDD[T: ClassTag](
}

val left = num - buf.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

res.foreach(buf ++= _.take(num - buf.size))
- partsScanned += numPartsToTry
+ partsScanned += p.size
}

buf.toArray
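To make the roll-over scenario debated in the thread above concrete, here is a small, self-contained illustration with hypothetical near-the-limit values (nothing below is taken from the patch; it only shows why the Long accumulator, followed by a .toInt on a bound that totalParts has already capped, cannot wrap).

object OverflowSketch extends App {
  // Hypothetical values far beyond what a real scheduler handles, as the
  // reviewers note; chosen only to show the arithmetic.
  val totalParts = Int.MaxValue - 2       // 2147483645
  val partsScannedInt = Int.MaxValue - 10 // 2147483637
  val numPartsToTryInt = 40               // a quadrupling step can overshoot the remaining gap

  // Int arithmetic wraps: the sum becomes a large negative number, so the
  // math.min cap picks the wrapped value and the partition range is bogus.
  val wrapped = partsScannedInt + numPartsToTryInt // -2147483619
  val badUpper = math.min(wrapped, totalParts)     // -2147483619

  // Long arithmetic keeps the true sum; totalParts caps it before the .toInt,
  // so the conversion back to Int is safe.
  val partsScanned = (Int.MaxValue - 10).toLong
  val numPartsToTry = 40L
  val goodUpper = math.min(partsScanned + numPartsToTry, totalParts.toLong).toInt // 2147483645

  println(s"Int: $badUpper, Long: $goodUpper")
}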
@@ -165,11 +165,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ

val buf = new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
- var partsScanned = 0
+ var partsScanned = 0L
while (buf.size < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
- var numPartsToTry = 1
+ var numPartsToTry = 1L
if (partsScanned > 0) {
// If we didn't find any rows after the first iteration, just try all partitions next.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -183,13 +183,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

val left = n - buf.size
- val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
val sc = sqlContext.sparkContext
val res =
sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)

res.foreach(buf ++= _.take(n - buf.size))
- partsScanned += numPartsToTry
+ partsScanned += p.size
}

buf.toArray
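For reference, the interpolation branch that the inline comments mention ("quadruple and retry ... otherwise, interpolate ... but overestimate") is hidden by the collapsed hunks in all three files. A rough paraphrase of that heuristic is sketched below; the 1.5x overestimate factor and the 4x per-round growth cap are assumptions about the surrounding code, not text from this patch.

// Sketch of the growth heuristic the comments describe: quadruple while no rows
// have been found, otherwise estimate how many partitions the remaining rows
// need, pad the estimate, and cap the growth per round.
def nextNumPartsToTry(partsScanned: Long, rowsSoFar: Int, rowsWanted: Int): Long = {
  require(partsScanned > 0, "only used after the first round")
  if (rowsSoFar == 0) {
    partsScanned * 4
  } else {
    // Rows seen per partition so far -> partitions needed for the rest, padded by 50%.
    val estimate = math.ceil(1.5 * rowsWanted * partsScanned / rowsSoFar).toLong - partsScanned
    math.min(math.max(estimate, 1L), partsScanned * 4)
  }
}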
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2067,4 +2067,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
}

test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 )
Contributor: Also remove the extra space before the comma here.

Contributor: For the SQL part I'd just move this into the existing limit test case, and add a line of comment explaining it.

You should also explain in the comment why 2147483638 is chosen as the value.

rdd.toDF("key").registerTempTable("spark12340")
checkAnswer(
sql("select key from spark12340 limit 2147483638"),
Row(1) :: Row(2) :: Row(3) :: Nil
)
assert(rdd.take(2147483638).size === 3)
Contributor: We should have a unit test in RDDSuite for the RDD tests, not in SQLQuerySuite.

assert(rdd.takeAsync(2147483638).get.size === 3)
}

}
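Following the last review comment, the RDD-side assertions would live in RDDSuite rather than SQLQuerySuite. A hedged sketch of what that might look like (a hypothetical test, not part of this patch; it assumes the usual SparkFunSuite and SharedSparkContext test mixins), with the magic number explained as requested above:

package org.apache.spark.rdd

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

// Hypothetical standalone suite; in practice this would be another test case
// inside the existing RDDSuite.
class TakeNearIntMaxSuite extends SparkFunSuite with SharedSparkContext {
  test("SPARK-12340: take/takeAsync where num is close to Int.MaxValue") {
    // 2147483638 is just below Int.MaxValue (2147483647), so with Int arithmetic
    // the interpolated numPartsToTry could push partsScanned + numPartsToTry past
    // the limit and wrap, even though this RDD has only 3 elements in 3 partitions.
    val rdd = sc.parallelize(1 to 3, 3)
    assert(rdd.take(2147483638).size === 3)
    assert(rdd.takeAsync(2147483638).get.size === 3)
  }
}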