diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 53a9f92b82bc..5f803f141d8c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -21,12 +21,13 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel + import org.apache.spark.graphx.impl.RoutingTablePartition import org.apache.spark.graphx.impl.ShippableVertexPartition import org.apache.spark.graphx.impl.VertexAttributeBlock import org.apache.spark.graphx.impl.VertexRDDImpl -import org.apache.spark.rdd._ -import org.apache.spark.storage.StorageLevel /** * Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by @@ -54,8 +55,8 @@ import org.apache.spark.storage.StorageLevel * @tparam VD the vertex attribute associated with each vertex in the set. */ abstract class VertexRDD[VD]( - sc: SparkContext, - deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { + @transient sc: SparkContext, + @transient deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { implicit protected def vdTag: ClassTag[VD] @@ -255,7 +256,8 @@ abstract class VertexRDD[VD]( shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] /** Generates an RDD of vertex IDs suitable for shipping to the edge partitions. */ - private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] + private[graphx] def shipVertexIds( + shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, Array[VertexId])] } // end of VertexRDD diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 81182adbc638..44fdee10011a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -21,11 +21,12 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl._ import org.apache.spark.graphx.util.BytecodeUtils -import org.apache.spark.rdd.{RDD, ShuffledRDD} -import org.apache.spark.storage.StorageLevel + /** * An implementation of [[org.apache.spark.graphx.Graph]] to support computation on graphs. @@ -222,13 +223,14 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( // For each vertex, replicate its attribute only to partitions where it is // in the relevant position in an edge. replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst) + val activeDirectionOpt = activeSetOpt.map(_._2) + val view = activeSetOpt match { case Some((activeSet, _)) => - replicatedVertexView.withActiveSet(activeSet) + replicatedVertexView.withActiveSet(activeSet, activeDirectionOpt) case None => replicatedVertexView } - val activeDirectionOpt = activeSetOpt.map(_._2) // Map and combine. val preAgg = view.edges.partitionsRDD.mapPartitions(_.flatMap { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index f79f9c7ec448..71ec829edcbc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -20,9 +20,10 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} import org.apache.spark.SparkContext._ -import org.apache.spark.graphx._ import org.apache.spark.rdd.RDD +import org.apache.spark.graphx._ + /** * Manages shipping vertex attributes to the edge partitions of an * [[org.apache.spark.graphx.EdgeRDD]]. Vertex attributes may be partially shipped to construct a @@ -85,12 +86,14 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * vertex ids present in `actives`. This ships a vertex id to all edge partitions where it is * referenced, ignoring the attribute shipping level. */ - def withActiveSet(actives: VertexRDD[_]): ReplicatedVertexView[VD, ED] = { - val shippedActives = actives.shipVertexIds() + def withActiveSetPosition(actives: VertexRDD[_], useSrc: Boolean, useDst: Boolean): + ReplicatedVertexView[VD, ED] = { + val shippedActives = actives.shipVertexIds(useSrc, useDst) .setName("ReplicatedVertexView.withActiveSet - shippedActives (broadcast)") .partitionBy(edges.partitioner.get) - val newEdges = edges.withPartitionsRDD(edges.partitionsRDD.zipPartitions(shippedActives) { + val newEdges = edges.withPartitionsRDD(edges.partitionsRDD + .zipPartitions(shippedActives) { (ePartIter, shippedActivesIter) => ePartIter.map { case (pid, edgePartition) => (pid, edgePartition.withActiveSet(shippedActivesIter.flatMap(_._2.iterator))) @@ -99,6 +102,24 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( new ReplicatedVertexView(newEdges, hasSrcId, hasDstId) } + + def withActiveSet(actives: VertexRDD[_], edgeDir: Option[EdgeDirection]): + ReplicatedVertexView[VD, ED] = { + edgeDir match { + case Some(EdgeDirection.Both) => + withActiveSetPosition(actives, true, true) + case Some(EdgeDirection.Either) => + withActiveSetPosition(actives, true, true) + case Some(EdgeDirection.Out) => + withActiveSetPosition(actives, true, false) + case Some(EdgeDirection.In) => + withActiveSetPosition(actives, false, true) + case _ => + this + } + } + + /** * Return a new `ReplicatedVertexView` where vertex attributes in edge partition are updated using * `updates`. This ships a vertex attribute only to the edge partitions where it is in the diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index 3f203c4eca48..12b60011ae5d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -19,9 +19,10 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag +import org.apache.spark.util.collection.{BitSet, PrimitiveVector} + import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap -import org.apache.spark.util.collection.{BitSet, PrimitiveVector} /** Stores vertex attributes to ship to an edge partition. */ private[graphx] @@ -134,11 +135,12 @@ class ShippableVertexPartition[VD: ClassTag]( * contains the visible vertex ids from the current partition that are referenced in the edge * partition. */ - def shipVertexIds(): Iterator[(PartitionID, Array[VertexId])] = { + def shipVertexIds(shipSrc: Boolean, shipDst: Boolean): + Iterator[(PartitionID, Array[VertexId])] = { Iterator.tabulate(routingTable.numEdgePartitions) { pid => val vids = new PrimitiveVector[VertexId](routingTable.partitionSize(pid)) var i = 0 - routingTable.foreachWithinEdgePartition(pid, true, true) { vid => + routingTable.foreachWithinEdgePartition(pid, shipSrc, shipDst) { vid => if (isDefined(vid)) { vids += vid } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index d5accdfbf7e9..4abb305853e0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -21,10 +21,11 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.SparkContext._ -import org.apache.spark.graphx._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.graphx._ + class VertexRDDImpl[VD] private[graphx] ( @transient val partitionsRDD: RDD[ShippableVertexPartition[VD]], val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) @@ -246,8 +247,9 @@ class VertexRDDImpl[VD] private[graphx] ( partitionsRDD.mapPartitions(_.flatMap(_.shipVertexAttributes(shipSrc, shipDst))) } - override private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] = { - partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds())) + override private[graphx] def shipVertexIds(shipSrc: Boolean, shipDst: Boolean): + RDD[(PartitionID, Array[VertexId])] = { + partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds(shipSrc, shipDst))) } }