diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 545bc0e9e99e..0ec7f7bcc477 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -27,7 +27,7 @@ import com.google.common.base.Optional
 import org.apache.hadoop.io.compress.CompressionCodec
 
 import org.apache.spark.{FutureAction, Partition, SparkContext, TaskContext}
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
@@ -81,7 +81,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def mapPartitionsWithIndex[R](
       f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]],
       preservesPartitioning: Boolean = false): JavaRDD[R] =
-    new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
+    new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))),
       preservesPartitioning)(fakeClassTag))(fakeClassTag)
 
   /**
@@ -185,6 +185,39 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2])
   }
 
+  /**
+   * :: DeveloperApi ::
+   * Return a new RDD by applying a function to each partition of this RDD. This is a variant of
+   * mapPartitions that also passes the TaskContext into the closure.
+   *
+   * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+   * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
+   */
+  @DeveloperApi
+  def mapPartitionsWithContext[R](
+      f: JFunction2[TaskContext, java.util.Iterator[T], java.util.Iterator[R]],
+      preservesPartitioning: Boolean): JavaRDD[R] = {
+
+    new JavaRDD(rdd.mapPartitionsWithContext(
+      ((a, b) => f(a, asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag)
+  }
+
+  /**
+   * :: DeveloperApi ::
+   * Return a new JavaPairRDD by applying a function to each partition of this RDD. This is a
+   * variant of mapPartitions that also passes the TaskContext into the closure.
+   *
+   * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+   * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
+   */
+  @DeveloperApi
+  def mapPartitionsToPairWithContext[K2, V2](
+      f: JFunction2[TaskContext, java.util.Iterator[T], java.util.Iterator[(K2, V2)]],
+      preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = {
+
+    JavaPairRDD.fromJavaRDD(mapPartitionsWithContext(f, preservesPartitioning))
+  }
+
   /**
    * Applies a function f to each partition of this RDD.
    */
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e1c13de04a0b..4e9f8eeeefdf 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -707,6 +707,86 @@ public Iterable<Integer> call(Iterator<Integer> iter) {
     Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
   }
 
+  @Test
+  public void mapPartitionsWithContext() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+    JavaRDD<String> partitionSumsWithContext = rdd.mapPartitionsWithContext(
+      new Function2<TaskContext, Iterator<Integer>, Iterator<String>>() {
+        @Override
+        public Iterator<String> call(TaskContext context,
+            Iterator<Integer> iter) throws Exception {
+
+          int sum = 0;
+          while (iter.hasNext()) {
+            sum += iter.next();
+          }
+          return Collections.singletonList(sum + "-partition-" + context.partitionId()).iterator();
+        }
+      }, false);
+    Assert.assertEquals("[3-partition-0, 7-partition-1]",
+      partitionSumsWithContext.collect().toString());
+  }
+
+  @Test
+  public void mapPartitionsToPair() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+    JavaPairRDD<Integer, String> pairRdd = rdd.mapPartitionsToPair(
+      new PairFlatMapFunction<Iterator<Integer>, Integer, String>() {
+        @Override
+        public Iterable<Tuple2<Integer, String>> call(Iterator<Integer> iter) throws Exception {
+          int sum = 0;
+          while (iter.hasNext()) {
+            sum += iter.next();
+          }
+          return Collections.singletonList(new Tuple2<Integer, String>(sum, "a"));
+        }
+      }
+    );
+
+    Assert.assertEquals("[(3,a), (7,a)]", pairRdd.collect().toString());
+  }
+
+  @Test
+  public void mapPartitionsToPairWithContext() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+    JavaPairRDD<Integer, String> pairRdd = rdd.mapPartitionsToPairWithContext(
+      new Function2<TaskContext, Iterator<Integer>, Iterator<Tuple2<Integer, String>>>() {
+        @Override
+        public Iterator<Tuple2<Integer, String>> call(TaskContext context,
+            Iterator<Integer> iter) throws Exception {
+
+          int sum = 0;
+          while (iter.hasNext()) {
+            sum += iter.next();
+          }
+          return Collections.singletonList(
+            new Tuple2<Integer, String>(sum, "partition-" + context.partitionId())).iterator();
+        }
+      }, false
+    );
+
+    Assert.assertEquals("[(3,partition-0), (7,partition-1)]", pairRdd.collect().toString());
+  }
+
+  @Test
+  public void mapPartitionsToDouble() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+    JavaDoubleRDD doubleRdd = rdd.mapPartitionsToDouble(
+      new DoubleFlatMapFunction<Iterator<Integer>>() {
+        @Override
+        public Iterable<Double> call(Iterator<Integer> iter) throws Exception {
+          int sum = 0;
+          while (iter.hasNext()) {
+            sum += iter.next();
+          }
+          return Collections.singletonList(Double.valueOf(sum));
+        }
+      }
+    );
+
+    Assert.assertEquals("[3.0, 7.0]", doubleRdd.collect().toString());
+  }
+
   @Test
   public void repartition() {
     // Shrinking number of partitions
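
Usage note (not part of the patch): a minimal sketch of how the new Java-facing
mapPartitionsWithContext could be called from application code. The JavaSparkContext
named "sc" is assumed to already exist; everything else uses only the API added or
exercised above.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Iterator;
    import java.util.List;

    import org.apache.spark.TaskContext;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function2;

    // Tag each element with the id of the partition that processed it.
    JavaRDD<String> tagged = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2)
        .mapPartitionsWithContext(
            new Function2<TaskContext, Iterator<Integer>, Iterator<String>>() {
              @Override
              public Iterator<String> call(TaskContext context, Iterator<Integer> iter) {
                List<String> out = new ArrayList<String>();
                while (iter.hasNext()) {
                  // TaskContext exposes the running task's metadata, e.g. partitionId().
                  out.add(iter.next() + "@" + context.partitionId());
                }
                return out.iterator();
              }
            },
            false);  // preservesPartitioning = false: this is not a pair RDD

    // tagged.collect() would yield something like [1@0, 2@0, 3@1, 4@1].

The false passed for preservesPartitioning mirrors the tests above: it should only be
true for a pair RDD whose keys the closure leaves unmodified.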