diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
index 897c7ee12a436..404b7d80e69f9 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
@@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl
 
 import scala.reflect.{classTag, ClassTag}
 
-import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext}
+import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
@@ -45,8 +45,13 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] (
    * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new
    * partitioner that allows co-partitioning with `partitionsRDD`.
    */
-  override val partitioner =
-    partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD)))
+  override val partitioner = {
+    if (partitionsRDD.partitioner.isDefined) {
+      partitionsRDD.partitioner
+    } else {
+      Some(new HashPartitioner(partitions.size))
+    }
+  }
 
   override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect()
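
For readers who want to see the fallback behavior outside of EdgeRDDImpl, here is a minimal standalone sketch of the same logic. The object name PartitionerFallbackSketch, the helper choosePartitioner, and the sample pair RDD are illustrative and not part of this patch; the only pieces taken from the diff are the isDefined check and the HashPartitioner fallback sized to the existing number of partitions.

// Standalone sketch (not part of the patch): reuse an existing upstream
// partitioner when one is defined, otherwise hash-partition across the
// current number of partitions so co-partitioned joins still line up.
import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object PartitionerFallbackSketch {
  // Mirrors the patched `partitioner` member, but against a plain pair RDD
  // instead of EdgeRDDImpl's partitionsRDD.
  def choosePartitioner(rdd: RDD[(Int, String)]): Option[Partitioner] = {
    if (rdd.partitioner.isDefined) {
      rdd.partitioner
    } else {
      Some(new HashPartitioner(rdd.partitions.size))
    }
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))

    // parallelize produces an RDD with no partitioner, so the fallback branch
    // fires and yields a HashPartitioner sized to the RDD's 4 partitions.
    val unpartitioned = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 4)
    println(choosePartitioner(unpartitioned))

    // After partitionBy, the existing HashPartitioner(2) is reused unchanged.
    val prePartitioned = unpartitioned.partitionBy(new HashPartitioner(2))
    println(choosePartitioner(prePartitioned))

    sc.stop()
  }
}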