diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 9544475ff042..80a4f8408746 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.api.java import java.{lang => jl} import java.lang.{Iterable => JIterable} -import java.util.{Comparator, List => JList} +import java.util.{Comparator, Iterator => JIterator, List => JList} import scala.collection.JavaConverters._ import scala.language.implicitConversions @@ -34,7 +34,8 @@ import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.Partitioner._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} +import org.apache.spark.api.java.function.{FlatMapFunction, Function => JFunction, + Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} import org.apache.spark.rdd.RDD.rddToPairRDDFunctions @@ -674,8 +675,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Pass each value in the key-value pair RDD through a flatMap function without changing the * keys; this also retains the original RDD's partitioning. */ - def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { - def fn: (V) => Iterable[U] = (x: V) => f.call(x).asScala + def flatMapValues[U](f: FlatMapFunction[V, U]): JavaPairRDD[K, U] = { + def fn: (V) => Iterator[U] = (x: V) => f.call(x).asScala implicit val ctag: ClassTag[U] = fakeClassTag fromRDD(rdd.flatMapValues(fn)) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bf85fe0b4512..851fa2334501 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,7 +36,9 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues") ) // Exclude rules for 2.4.x diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 2ec907c8cfd5..c3c13df651cc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -34,7 +34,8 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils, Optional} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.api.java.function.{FlatMapFunction, Function => JFunction, + Function2 => JFunction2} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ @@ -562,9 +563,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Return a new DStream by applying a flatmap function to the value of each key-value pairs in * 'this' DStream without changing the key. */ - def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = { - import scala.collection.JavaConverters._ - def fn: (V) => Iterable[U] = (x: V) => f.apply(x).asScala + def flatMapValues[U](f: FlatMapFunction[V, U]): JavaPairDStream[K, U] = { + def fn: (V) => Iterator[U] = (x: V) => f.call(x).asScala implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] dstream.flatMapValues(fn) diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java index 90d1f8c5035b..b154f0e3ac45 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java @@ -841,7 +841,7 @@ public void testFlatMapValues() { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream flatMapped = - pairStream.flatMapValues(in -> Arrays.asList(in + "1", in + "2")); + pairStream.flatMapValues(in -> Arrays.asList(in + "1", in + "2").iterator()); JavaTestUtils.attachTestOutputStream(flatMapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index 6c86cacec827..c7cde5674f54 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -1355,7 +1355,7 @@ public void testFlatMapValues() { List out = new ArrayList<>(); out.add(in + "1"); out.add(in + "2"); - return out; + return out.iterator(); }); JavaTestUtils.attachTestOutputStream(flatMapped);