diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6dc334ceb52e..be119578d2c3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1278,7 +1278,7 @@ abstract class RDD[T: ClassTag]( def zipWithUniqueId(): RDD[(T, Long)] = withScope { val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => - iter.zipWithIndex.map { case (item, i) => + Utils.getIteratorZipWithIndex(iter, 0L).map { case (item, i) => (item, i * n + k) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index b5738b9a95c3..b0e5ba0865c6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -64,8 +64,7 @@ class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = { val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition] - firstParent[T].iterator(split.prev, context).zipWithIndex.map { x => - (x._1, split.startIndex + x._2) - } + val parentIter = firstParent[T].iterator(split.prev, context) + Utils.getIteratorZipWithIndex(parentIter, split.startIndex) } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ef832756ce3b..ad3418d38c0d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1698,6 +1698,21 @@ private[spark] object Utils extends Logging { count } + /** + * Pairs each element of the iterator with a Long index counting up from startIndex, avoiding + * the index overflow problem in Scala's zipWithIndex. + */ + def getIteratorZipWithIndex[T](iterator: Iterator[T], startIndex: Long): Iterator[(T, Long)] = { + new 
Iterator[(T, Long)] { + var index: Long = startIndex - 1L + def hasNext: Boolean = iterator.hasNext + def next(): (T, Long) = { + index += 1L + (iterator.next(), index) + } + } + } + /** * Creates a symlink. * diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index bc28b2d9cb83..ca257c74635d 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -350,6 +350,13 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.getIteratorSize(iterator) === 5L) } + test("getIteratorZipWithIndex") { + val iterator = Utils.getIteratorZipWithIndex(Iterator(0, 1, 2), -1L + Int.MaxValue) + assert(iterator.toArray === Array( + (0, -1L + Int.MaxValue), (1, 0L + Int.MaxValue), (2, 1L + Int.MaxValue) + )) + } + test("doesDirectoryContainFilesNewerThan") { // create some temporary directories and files val parent: File = Utils.createTempDir()