From 673c29b2166e002d97b914ef8f8316df71fc8be7 Mon Sep 17 00:00:00 2001 From: codlife <1004910847@qq.com> Date: Sat, 10 Sep 2016 10:02:21 +0800 Subject: [PATCH 1/4] solve spark-17447 --- core/src/main/scala/org/apache/spark/Partitioner.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 98c3abe93b55..54596d0650a5 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -55,14 +55,16 @@ object Partitioner { * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD. */ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { - val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.length).reverse - for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) { - return r.partitioner.get + val rdds = Seq(rdd) ++ others + + val filteredRdds = rdds.filter( _.partitioner.exists(_.numPartitions > 0 )) + if(filteredRdds.nonEmpty) { + return filteredRdds.maxBy( _.partitions.length).partitioner.get } if (rdd.context.conf.contains("spark.default.parallelism")) { new HashPartitioner(rdd.context.defaultParallelism) } else { - new HashPartitioner(bySize.head.partitions.length) + new HashPartitioner(rdds.map(_.partitions.length).max) } } } From a4609059350af3ebeb68e5acdfc99daf424a817a Mon Sep 17 00:00:00 2001 From: codlife <1004910847@qq.com> Date: Sat, 10 Sep 2016 10:26:46 +0800 Subject: [PATCH 2/4] Update Partitioner.scala --- core/src/main/scala/org/apache/spark/Partitioner.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 54596d0650a5..9fb1c4c05551 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -56,7 +56,6 @@ object Partitioner { */ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { val rdds = Seq(rdd) ++ others - val filteredRdds = rdds.filter( _.partitioner.exists(_.numPartitions > 0 )) if(filteredRdds.nonEmpty) { return filteredRdds.maxBy( _.partitions.length).partitioner.get From 8ddc442fc40f71d85fcaef8e4a721f6b31a5ea5c Mon Sep 17 00:00:00 2001 From: codlife <1004910847@qq.com> Date: Sat, 10 Sep 2016 20:33:19 +0800 Subject: [PATCH 3/4] fix code style --- .../scala/org/apache/spark/Partitioner.scala | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 9fb1c4c05551..31b7d89e35a4 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -55,15 +55,16 @@ object Partitioner { * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD. */ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { - val rdds = Seq(rdd) ++ others - val filteredRdds = rdds.filter( _.partitioner.exists(_.numPartitions > 0 )) - if(filteredRdds.nonEmpty) { - return filteredRdds.maxBy( _.partitions.length).partitioner.get - } - if (rdd.context.conf.contains("spark.default.parallelism")) { - new HashPartitioner(rdd.context.defaultParallelism) + val rdds = (Seq(rdd) ++ others) + val hashPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0)) + if (hashPartitioner.nonEmpty) { + hashPartitioner.maxBy(_.partitions.length).partitioner.get } else { - new HashPartitioner(rdds.map(_.partitions.length).max) + if (rdd.context.conf.contains("spark.default.parallelism")) { + new HashPartitioner(rdd.context.defaultParallelism) + } else { + new HashPartitioner(rdds.map(_.partitions.length).max) + } } } } From f5d1e24d38f4a24f2ebc29214eb1a331846a0b1b Mon Sep 17 00:00:00 2001 From: codlife <1004910847@qq.com> Date: Sat, 10 Sep 2016 23:21:44 +0800 Subject: [PATCH 4/4] Update Partitioner.scala --- core/src/main/scala/org/apache/spark/Partitioner.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 31b7d89e35a4..93dfbc0e6ed6 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -56,9 +56,9 @@ object Partitioner { */ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { val rdds = (Seq(rdd) ++ others) - val hashPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0)) - if (hashPartitioner.nonEmpty) { - hashPartitioner.maxBy(_.partitions.length).partitioner.get + val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0)) + if (hasPartitioner.nonEmpty) { + hasPartitioner.maxBy(_.partitions.length).partitioner.get } else { if (rdd.context.conf.contains("spark.default.parallelism")) { new HashPartitioner(rdd.context.defaultParallelism)