diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a9b905b0d1a63..a13f71fd4de01 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -207,6 +207,21 @@ abstract class RDD[T: ClassTag](
     }
   }
 
+  /**
+   * Get the number of partitions in this RDD.
+   *
+   * {{{
+   * scala> val rdd = sc.parallelize(1 to 4, 2)
+   * rdd: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[1] at parallelize at <console>:12
+   *
+   * scala> rdd.getNumPartitions
+   * res1: Int = 2
+   * }}}
+   *
+   * @return The number of partitions in this RDD.
+   */
+  def getNumPartitions: Int = partitions.size
+
   /**
    * Get the preferred locations of a partition (as hostnames), taking into account whether the
    * RDD is checkpointed.
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index c1b501a75c8b8..dba45996be7f9 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -34,6 +34,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
 
   test("basic operations") {
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    assert(nums.getNumPartitions === 2)
     assert(nums.collect().toList === List(1, 2, 3, 4))
     assert(nums.toLocalIterator.toList === List(1, 2, 3, 4))
     val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
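
For anyone trying the patch locally, here is a minimal sketch of how the new accessor reads from application code. The object name, the local[2] master, and the app name are illustrative only; everything else is existing Spark API plus the getNumPartitions method added above, which simply delegates to partitions.size.

import org.apache.spark.{SparkConf, SparkContext}

object GetNumPartitionsExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("getNumPartitions-example"))

    // Before this patch, callers had to reach through to rdd.partitions.size.
    val rdd = sc.parallelize(1 to 4, numSlices = 2)
    assert(rdd.partitions.size == 2)

    // With this patch, the partition count is a first-class accessor.
    assert(rdd.getNumPartitions == 2)

    // Because it is derived from partitions, the count tracks repartitioning.
    assert(rdd.repartition(4).getNumPartitions == 4)

    sc.stop()
  }
}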