From 4a566dc86624ac3f6dfa747d344c86e4be44adc2 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 13 Aug 2014 19:33:47 -0700 Subject: [PATCH 1/7] Optimizations for mapReduceTriplets and EdgePartition 1. EdgePartition now stores local vertex ids instead of global ids. This avoids hash lookups when looking up vertex attributes and aggregating messages. 2. Internal iterators in mapReduceTriplets are inlined into a while loop. --- .../spark/graphx/impl/EdgePartition.scala | 262 +++++++++++++----- .../graphx/impl/EdgePartitionBuilder.scala | 95 ++++++- .../graphx/impl/EdgeTripletIterator.scala | 39 +-- .../apache/spark/graphx/impl/GraphImpl.scala | 46 +-- .../graphx/impl/RoutingTablePartition.scala | 8 +- .../org/apache/spark/graphx/GraphSuite.scala | 4 +- .../graphx/impl/EdgePartitionSuite.scala | 31 +-- 7 files changed, 310 insertions(+), 175 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index a5c9cd1f8b4e6..52661aa5f1d3c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -21,6 +21,7 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet /** * A collection of edges stored in columnar format, along with any vertex attributes referenced. The @@ -30,54 +31,76 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap * @tparam ED the edge attribute type * @tparam VD the vertex attribute type * - * @param srcIds the source vertex id of each edge - * @param dstIds the destination vertex id of each edge + * @param localSrcIds the local source vertex id of each edge as an index into `local2global` and + * `vertexAttrs` + * @param localDstIds the local destination vertex id of each edge as an index into `local2global` + * and `vertexAttrs` * @param data the attribute associated with each edge - * @param index a clustered index on source vertex id - * @param vertices a map from referenced vertex ids to their corresponding attributes. Must - * contain all vertex ids from `srcIds` and `dstIds`, though not necessarily valid attributes for - * those vertex ids. The mask is not used. 
+ * @param index a clustered index on source vertex id as a map from each global source vertex id to + * the offset in the edge arrays where the cluster for that vertex id begins + * @param global2local a map from referenced vertex ids to local ids which index into vertexAttrs + * @param local2global an array of global vertex ids where the offsets are local vertex ids + * @param vertexAttrs an array of vertex attributes where the offsets are local vertex ids * @param activeSet an optional active vertex set for filtering computation on the edges */ private[graphx] class EdgePartition[ @specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag]( - val srcIds: Array[VertexId] = null, - val dstIds: Array[VertexId] = null, + val localSrcIds: Array[Int] = null, + val localDstIds: Array[Int] = null, val data: Array[ED] = null, val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null, - val vertices: VertexPartition[VD] = null, + val global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null, + val local2global: Array[VertexId] = null, + val vertexAttrs: Array[VD] = null, val activeSet: Option[VertexSet] = None ) extends Serializable { /** Return a new `EdgePartition` with the specified edge data. */ - def withData[ED2: ClassTag](data_ : Array[ED2]): EdgePartition[ED2, VD] = { - new EdgePartition(srcIds, dstIds, data_, index, vertices, activeSet) - } - - /** Return a new `EdgePartition` with the specified vertex partition. */ - def withVertices[VD2: ClassTag]( - vertices_ : VertexPartition[VD2]): EdgePartition[ED, VD2] = { - new EdgePartition(srcIds, dstIds, data, index, vertices_, activeSet) + def withData[ED2: ClassTag](data: Array[ED2]): EdgePartition[ED2, VD] = { + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet) } /** Return a new `EdgePartition` with the specified active set, provided as an iterator. */ def withActiveSet(iter: Iterator[VertexId]): EdgePartition[ED, VD] = { - val newActiveSet = new VertexSet - iter.foreach(newActiveSet.add(_)) - new EdgePartition(srcIds, dstIds, data, index, vertices, Some(newActiveSet)) + val activeSet = new VertexSet + iter.foreach(activeSet.add(_)) + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, + Some(activeSet)) } /** Return a new `EdgePartition` with the specified active set. */ - def withActiveSet(activeSet_ : Option[VertexSet]): EdgePartition[ED, VD] = { - new EdgePartition(srcIds, dstIds, data, index, vertices, activeSet_) + def withActiveSet(activeSet: Option[VertexSet]): EdgePartition[ED, VD] = { + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet) } /** Return a new `EdgePartition` with updates to vertex attributes specified in `iter`. */ def updateVertices(iter: Iterator[(VertexId, VD)]): EdgePartition[ED, VD] = { - this.withVertices(vertices.innerJoinKeepLeft(iter)) + val newVertexAttrs = new Array[VD](vertexAttrs.length) + System.arraycopy(vertexAttrs, 0, newVertexAttrs, 0, vertexAttrs.length) + iter.foreach { kv => + newVertexAttrs(global2local(kv._1)) = kv._2 + } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs, + activeSet) + } + + /** Return a new `EdgePartition` without any locally cached vertex attributes. 
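+   * The edge structure and local id mappings are kept as-is; only the attribute array is
+   * replaced with an empty one of the new type `VD2`, to be refilled later via `updateVertices`.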
*/ + def clearVertices[VD2: ClassTag](): EdgePartition[ED, VD2] = { + val newVertexAttrs = new Array[VD2](vertexAttrs.length) + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs, + activeSet) } + def srcIds(i: Int): VertexId = local2global(localSrcIds(i)) + + def dstIds(i: Int): VertexId = local2global(localDstIds(i)) + /** Look up vid in activeSet, throwing an exception if it is None. */ def isActive(vid: VertexId): Boolean = { activeSet.get.contains(vid) @@ -92,11 +115,19 @@ class EdgePartition[ * @return a new edge partition with all edges reversed. */ def reverse: EdgePartition[ED, VD] = { - val builder = new EdgePartitionBuilder(size)(classTag[ED], classTag[VD]) - for (e <- iterator) { - builder.add(e.dstId, e.srcId, e.attr) + val builder = new VertexPreservingEdgePartitionBuilder( + global2local, local2global, vertexAttrs, size)(classTag[ED], classTag[VD]) + var i = 0 + while (i < size) { + val localSrcId = localSrcIds(i) + val localDstId = localDstIds(i) + val srcId = local2global(localSrcId) + val dstId = local2global(localDstId) + val attr = data(i) + builder.add(dstId, srcId, localDstId, localSrcId, attr) + i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition.withActiveSet(activeSet) } /** @@ -157,13 +188,25 @@ class EdgePartition[ def filter( epred: EdgeTriplet[VD, ED] => Boolean, vpred: (VertexId, VD) => Boolean): EdgePartition[ED, VD] = { - val filtered = tripletIterator().filter(et => - vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) - val builder = new EdgePartitionBuilder[ED, VD] - for (e <- filtered) { - builder.add(e.srcId, e.dstId, e.attr) + val builder = new VertexPreservingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs) + var i = 0 + while (i < size) { + // The user sees the EdgeTriplet, so we can't reuse it and must create one per edge. 
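+        // (Internal scans such as mapReduceTriplets below avoid this allocation by reusing a
+        // single object across all edges.)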
+ val localSrcId = localSrcIds(i) + val localDstId = localDstIds(i) + val et = new EdgeTriplet[VD, ED] + et.srcId = local2global(localSrcId) + et.dstId = local2global(localDstId) + et.srcAttr = vertexAttrs(localSrcId) + et.dstAttr = vertexAttrs(localDstId) + et.attr = data(i) + if (vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) { + builder.add(et.srcId, et.dstId, localSrcId, localDstId, et.attr) + } + i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition.withActiveSet(activeSet) } /** @@ -183,7 +226,8 @@ class EdgePartition[ * @return a new edge partition without duplicate edges */ def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED, VD] = { - val builder = new EdgePartitionBuilder[ED, VD] + val builder = new VertexPreservingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs) var currSrcId: VertexId = null.asInstanceOf[VertexId] var currDstId: VertexId = null.asInstanceOf[VertexId] var currAttr: ED = null.asInstanceOf[ED] @@ -193,7 +237,7 @@ class EdgePartition[ currAttr = merge(currAttr, data(i)) } else { if (i > 0) { - builder.add(currSrcId, currDstId, currAttr) + builder.add(currSrcId, currDstId, localSrcIds(i - 1), localDstIds(i - 1), currAttr) } currSrcId = srcIds(i) currDstId = dstIds(i) @@ -202,9 +246,9 @@ class EdgePartition[ i += 1 } if (size > 0) { - builder.add(currSrcId, currDstId, currAttr) + builder.add(currSrcId, currDstId, localSrcIds(i - 1), localDstIds(i - 1), currAttr) } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition.withActiveSet(activeSet) } /** @@ -220,7 +264,8 @@ class EdgePartition[ def innerJoin[ED2: ClassTag, ED3: ClassTag] (other: EdgePartition[ED2, _]) (f: (VertexId, VertexId, ED, ED2) => ED3): EdgePartition[ED3, VD] = { - val builder = new EdgePartitionBuilder[ED3, VD] + val builder = new VertexPreservingEdgePartitionBuilder[ED3, VD]( + global2local, local2global, vertexAttrs) var i = 0 var j = 0 // For i = index of each edge in `this`... @@ -233,12 +278,13 @@ class EdgePartition[ while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 } if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) { // ... run `f` on the matching edge - builder.add(srcId, dstId, f(srcId, dstId, this.data(i), other.data(j))) + builder.add(srcId, dstId, localSrcIds(i), localDstIds(i), + f(srcId, dstId, this.data(i), other.data(j))) } } i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition.withActiveSet(activeSet) } /** @@ -246,7 +292,7 @@ class EdgePartition[ * * @return size of the partition */ - val size: Int = srcIds.size + val size: Int = localSrcIds.size /** The number of unique source vertices in the partition. */ def indexSize: Int = index.size @@ -285,50 +331,116 @@ class EdgePartition[ } /** - * Upgrade the given edge iterator into a triplet iterator. + * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning + * all edges sequentially and filtering them with `idPred`. 
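+   *
+   * A hypothetical usage sketch (assuming a partition `part` whose active set is defined):
+   * counting the in-edges of active destination vertices only:
+   * {{{
+   * part.mapReduceTriplets[Int](
+   *   et => Iterator((et.dstId, 1)), _ + _,
+   *   false, false,  // mapUsesSrcAttr, mapUsesDstAttr: no vertex attributes are read
+   *   (srcId, dstId) => part.isActive(dstId))
+   * }}}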
+ * + * @param mapFunc the edge map function which generates messages to neighboring vertices + * @param reduceFunc the combiner applied to messages destined to the same vertex + * @param mapUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute + * @param mapUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute + * @param idPred a predicate to filter edges based on their source and destination vertex ids * - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. + * @return iterator aggregated messages keyed by the receiving vertex id */ - def upgradeIterator( - edgeIter: Iterator[Edge[ED]], includeSrc: Boolean = true, includeDst: Boolean = true) - : Iterator[EdgeTriplet[VD, ED]] = { - new ReusingEdgeTripletIterator(edgeIter, this, includeSrc, includeDst) + def mapReduceTriplets[A: ClassTag]( + mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], + reduceFunc: (A, A) => A, + mapUsesSrcAttr: Boolean, + mapUsesDstAttr: Boolean, + idPred: (VertexId, VertexId) => Boolean): Iterator[(VertexId, A)] = { + val aggregates = new Array[A](vertexAttrs.length) + val bitset = new BitSet(vertexAttrs.length) + + var edge = new EdgeTriplet[VD, ED] + var i = 0 + while (i < size) { + val localSrcId = localSrcIds(i) + val srcId = local2global(localSrcId) + val localDstId = localDstIds(i) + val dstId = local2global(localDstId) + if (idPred(srcId, dstId)) { + edge.srcId = srcId + edge.dstId = dstId + edge.attr = data(i) + if (mapUsesSrcAttr) { edge.srcAttr = vertexAttrs(localSrcId) } + if (mapUsesDstAttr) { edge.dstAttr = vertexAttrs(localDstId) } + + mapFunc(edge).foreach { kv => + val globalId = kv._1 + val msg = kv._2 + val localId = if (globalId == srcId) localSrcId else localDstId + if (bitset.get(localId)) { + aggregates(localId) = reduceFunc(aggregates(localId), msg) + } else { + aggregates(localId) = msg + bitset.set(localId) + } + } + } + i += 1 + } + + bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) } } /** - * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The - * iterator is generated using an index scan, so it is efficient at skipping edges that don't - * match srcIdPred. + * Send messages along edges and aggregate them at the receiving vertices. Implemented by + * filtering the source vertex index with `srcIdPred`, then scanning edge clusters and filtering + * with `dstIdPred`. Both `srcIdPred` and `dstIdPred` must match for an edge to run. * - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. - */ - def indexIterator(srcIdPred: VertexId => Boolean): Iterator[Edge[ED]] = - index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator)) - - /** - * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The - * cluster must start at position `index`. 
+ * @param mapFunc the edge map function which generates messages to neighboring vertices + * @param reduceFunc the combiner applied to messages destined to the same vertex + * @param mapUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute + * @param mapUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute + * @param srcIdPred a predicate to filter edges based on their source vertex id + * @param dstIdPred a predicate to filter edges based on their destination vertex id * - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. + * @return iterator aggregated messages keyed by the receiving vertex id */ - private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] { - private[this] val edge = new Edge[ED] - private[this] var pos = index + def mapReduceTripletsWithIndex[A: ClassTag]( + mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], + reduceFunc: (A, A) => A, + mapUsesSrcAttr: Boolean, + mapUsesDstAttr: Boolean, + srcIdPred: VertexId => Boolean, + dstIdPred: VertexId => Boolean): Iterator[(VertexId, A)] = { + val aggregates = new Array[A](vertexAttrs.length) + val bitset = new BitSet(vertexAttrs.length) - override def hasNext: Boolean = { - pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId - } + var edge = new EdgeTriplet[VD, ED] + index.iterator.foreach { cluster => + val clusterSrcId = cluster._1 + val clusterPos = cluster._2 + val clusterLocalSrcId = localSrcIds(clusterPos) + if (srcIdPred(clusterSrcId)) { + var pos = clusterPos + edge.srcId = clusterSrcId + if (mapUsesSrcAttr) { edge.srcAttr = vertexAttrs(clusterLocalSrcId) } + while (pos < size && localSrcIds(pos) == clusterLocalSrcId) { + val localDstId = localDstIds(pos) + val dstId = local2global(localDstId) + if (dstIdPred(dstId)) { + edge.dstId = dstId + edge.attr = data(pos) + if (mapUsesDstAttr) { edge.dstAttr = vertexAttrs(localDstId) } - override def next(): Edge[ED] = { - assert(srcIds(pos) == srcId) - edge.srcId = srcIds(pos) - edge.dstId = dstIds(pos) - edge.attr = data(pos) - pos += 1 - edge + mapFunc(edge).foreach { kv => + val globalId = kv._1 + val msg = kv._2 + val localId = if (globalId == clusterSrcId) clusterLocalSrcId else localDstId + if (bitset.get(localId)) { + aggregates(localId) = reduceFunc(aggregates(localId), msg) + } else { + aggregates(localId) = msg + bitset.set(localId) + } + } + } + pos += 1 + } + } } + + bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 4520beb991515..675247d1686a9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -25,6 +25,7 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveVector} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +/** Constructs an EdgePartition from scratch. 
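+ * A minimal usage sketch with hypothetical vertex ids and attributes:
+ * {{{
+ * val builder = new EdgePartitionBuilder[Int, Int]
+ * builder.add(1L, 2L, 10)
+ * builder.add(1L, 3L, 20)
+ * val part = builder.toEdgePartition  // sorts the edges and builds the clustered source index
+ * }}}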
*/ private[graphx] class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( size: Int = 64) { @@ -38,19 +39,76 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla def toEdgePartition: EdgePartition[ED, VD] = { val edgeArray = edges.trim().array Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering) - val srcIds = new Array[VertexId](edgeArray.size) - val dstIds = new Array[VertexId](edgeArray.size) + val localSrcIds = new Array[Int](edgeArray.size) + val localDstIds = new Array[Int](edgeArray.size) + val data = new Array[ED](edgeArray.size) + val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] + val global2local = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] + val local2global = new PrimitiveVector[VertexId] + var vertexAttrs = Array.empty[VD] + // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and + // adding them to the index. Also populate a map from vertex id to a sequential local offset. + if (edgeArray.length > 0) { + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId + var currLocalId = -1 + var i = 0 + while (i < edgeArray.size) { + val srcId = edgeArray(i).srcId + val dstId = edgeArray(i).dstId + localSrcIds(i) = global2local.changeValue(srcId, + { currLocalId += 1; local2global += srcId; currLocalId }, identity) + localDstIds(i) = global2local.changeValue(dstId, + { currLocalId += 1; local2global += dstId; currLocalId }, identity) + data(i) = edgeArray(i).attr + if (srcId != currSrcId) { + currSrcId = srcId + index.update(currSrcId, i) + } + + i += 1 + } + vertexAttrs = new Array[VD](currLocalId + 1) + } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs) + } +} + +/** + * Constructs an EdgePartition from an existing EdgePartition with the same vertex set. This enables + * reuse of the local vertex ids. + */ +private[graphx] +class VertexPreservingEdgePartitionBuilder[ + @specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( + global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + local2global: Array[VertexId], + vertexAttrs: Array[VD], + size: Int = 64) { + var edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size) + + /** Add a new edge to the partition. 
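+   * The local ids must be consistent with the `global2local` map supplied to this builder; they
+   * are stored as given, without validation.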
*/ + def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED) { + edges += EdgeWithLocalIds(src, dst, localSrc, localDst, d) + } + + def toEdgePartition: EdgePartition[ED, VD] = { + val edgeArray = edges.trim().array + Sorting.quickSort(edgeArray)(EdgeWithLocalIds.lexicographicOrdering) + val localSrcIds = new Array[Int](edgeArray.size) + val localDstIds = new Array[Int](edgeArray.size) val data = new Array[ED](edgeArray.size) val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and // adding them to the index if (edgeArray.length > 0) { - index.update(srcIds(0), 0) - var currSrcId: VertexId = srcIds(0) + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId var i = 0 while (i < edgeArray.size) { - srcIds(i) = edgeArray(i).srcId - dstIds(i) = edgeArray(i).dstId + localSrcIds(i) = edgeArray(i).localSrcId + localDstIds(i) = edgeArray(i).localDstId data(i) = edgeArray(i).attr if (edgeArray(i).srcId != currSrcId) { currSrcId = edgeArray(i).srcId @@ -60,13 +118,24 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla } } - // Create and populate a VertexPartition with vids from the edges, but no attributes - val vidsIter = srcIds.iterator ++ dstIds.iterator - val vertexIds = new OpenHashSet[VertexId] - vidsIter.foreach(vid => vertexIds.add(vid)) - val vertices = new VertexPartition( - vertexIds, new Array[VD](vertexIds.capacity), vertexIds.getBitSet) + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs) + } +} - new EdgePartition(srcIds, dstIds, data, index, vertices) +private[graphx] case class EdgeWithLocalIds[@specialized ED]( + srcId: VertexId, dstId: VertexId, localSrcId: Int, localDstId: Int, attr: ED) + +private[graphx] object EdgeWithLocalIds { + implicit def lexicographicOrdering[ED] = new Ordering[EdgeWithLocalIds[ED]] { + override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = { + if (a.srcId == b.srcId) { + if (a.dstId == b.dstId) 0 + else if (a.dstId < b.dstId) -1 + else 1 + } else if (a.srcId < b.srcId) -1 + else 1 + } } + } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala index 56f79a7097fce..a8f829ed20a34 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala @@ -40,45 +40,18 @@ class EdgeTripletIterator[VD: ClassTag, ED: ClassTag]( override def next() = { val triplet = new EdgeTriplet[VD, ED] - triplet.srcId = edgePartition.srcIds(pos) + val localSrcId = edgePartition.localSrcIds(pos) + val localDstId = edgePartition.localDstIds(pos) + triplet.srcId = edgePartition.local2global(localSrcId) + triplet.dstId = edgePartition.local2global(localDstId) if (includeSrc) { - triplet.srcAttr = edgePartition.vertices(triplet.srcId) + triplet.srcAttr = edgePartition.vertexAttrs(localSrcId) } - triplet.dstId = edgePartition.dstIds(pos) if (includeDst) { - triplet.dstAttr = edgePartition.vertices(triplet.dstId) + triplet.dstAttr = edgePartition.vertexAttrs(localDstId) } triplet.attr = edgePartition.data(pos) pos += 1 triplet } } - -/** - * An Iterator type for internal use that reuses EdgeTriplet objects. 
This could be an anonymous - * class in EdgePartition.upgradeIterator, but we name it here explicitly so it is easier to debug / - * profile. - */ -private[impl] -class ReusingEdgeTripletIterator[VD: ClassTag, ED: ClassTag]( - val edgeIter: Iterator[Edge[ED]], - val edgePartition: EdgePartition[ED, VD], - val includeSrc: Boolean, - val includeDst: Boolean) - extends Iterator[EdgeTriplet[VD, ED]] { - - private val triplet = new EdgeTriplet[VD, ED] - - override def hasNext = edgeIter.hasNext - - override def next() = { - triplet.set(edgeIter.next()) - if (includeSrc) { - triplet.srcAttr = edgePartition.vertices(triplet.srcId) - } - if (includeDst) { - triplet.dstAttr = edgePartition.vertices(triplet.dstId) - } - triplet - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 33f35cfb69a26..1188e2ad91821 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -23,7 +23,6 @@ import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.storage.StorageLevel - import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl._ import org.apache.spark.graphx.util.BytecodeUtils @@ -193,37 +192,44 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( case (pid, edgePartition) => // Choose scan method val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat - val edgeIter = activeDirectionOpt match { + activeDirectionOpt match { case Some(EdgeDirection.Both) => if (activeFraction < 0.8) { - edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId)) - .filter(e => edgePartition.isActive(e.dstId)) + edgePartition.mapReduceTripletsWithIndex( + mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + srcId => edgePartition.isActive(srcId), + dstId => edgePartition.isActive(dstId)) } else { - edgePartition.iterator.filter(e => - edgePartition.isActive(e.srcId) && edgePartition.isActive(e.dstId)) + edgePartition.mapReduceTriplets( + mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + (srcId, dstId) => edgePartition.isActive(srcId) && edgePartition.isActive(dstId)) } case Some(EdgeDirection.Either) => // TODO: Because we only have a clustered index on the source vertex ID, we can't filter // the index here. Instead we have to scan all edges and then do the filter. 
- edgePartition.iterator.filter(e => - edgePartition.isActive(e.srcId) || edgePartition.isActive(e.dstId)) + edgePartition.mapReduceTriplets( + mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + (srcId, dstId) => edgePartition.isActive(srcId) || edgePartition.isActive(dstId)) case Some(EdgeDirection.Out) => if (activeFraction < 0.8) { - edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId)) + edgePartition.mapReduceTripletsWithIndex( + mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + srcId => edgePartition.isActive(srcId), + dstId => true) } else { - edgePartition.iterator.filter(e => edgePartition.isActive(e.srcId)) + edgePartition.mapReduceTriplets( + mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + (srcId, dstId) => edgePartition.isActive(srcId)) } case Some(EdgeDirection.In) => - edgePartition.iterator.filter(e => edgePartition.isActive(e.dstId)) + edgePartition.mapReduceTriplets( + mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + (srcId, dstId) => edgePartition.isActive(dstId)) case _ => // None - edgePartition.iterator + edgePartition.mapReduceTriplets( + mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + (srcId, dstId) => true) } - - // Scan edges and run the map function - val mapOutputs = edgePartition.upgradeIterator(edgeIter, mapUsesSrcAttr, mapUsesDstAttr) - .flatMap(mapFunc(_)) - // Note: This doesn't allow users to send messages to arbitrary vertices. - edgePartition.vertices.aggregateUsingIndex(mapOutputs, reduceFunc).iterator }).setName("GraphImpl.mapReduceTriplets - preAgg") // do the final reduction reusing the index map @@ -306,9 +312,7 @@ object GraphImpl { vertices: VertexRDD[VD], edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = { // Convert the vertex partitions in edges to the correct type - val newEdges = edges.mapEdgePartitions( - (pid, part) => part.withVertices(part.vertices.map( - (vid, attr) => null.asInstanceOf[VD]))) + val newEdges = edges.mapEdgePartitions((pid, part) => part.clearVertices[VD]) GraphImpl.fromExistingRDDs(vertices, newEdges) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index b27485953f719..4bd4d8e6b9ddf 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -74,11 +74,9 @@ object RoutingTablePartition { // Determine which positions each vertex id appears in using a map where the low 2 bits // represent src and dst val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, Byte] - edgePartition.srcIds.iterator.foreach { srcId => - map.changeValue(srcId, 0x1, (b: Byte) => (b | 0x1).toByte) - } - edgePartition.dstIds.iterator.foreach { dstId => - map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte) + edgePartition.iterator.foreach { e => + map.changeValue(e.srcId, 0x1, (b: Byte) => (b | 0x1).toByte) + map.changeValue(e.dstId, 0x2, (b: Byte) => (b | 0x2).toByte) } map.iterator.map { vidAndPosition => val vid = vidAndPosition._1 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 6506bac73d71c..697afef29029c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -118,7 +118,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { // Each vertex should be 
replicated to at most 2 * sqrt(p) partitions val partitionSets = partitionedGraph.edges.partitionsRDD.mapPartitions { iter => val part = iter.next()._2 - Iterator((part.srcIds ++ part.dstIds).toSet) + Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet) }.collect if (!verts.forall(id => partitionSets.count(_.contains(id)) <= bound)) { val numFailures = verts.count(id => partitionSets.count(_.contains(id)) > bound) @@ -130,7 +130,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { // This should not be true for the default hash partitioning val partitionSetsUnpartitioned = graph.edges.partitionsRDD.mapPartitions { iter => val part = iter.next()._2 - Iterator((part.srcIds ++ part.dstIds).toSet) + Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet) }.collect assert(verts.exists(id => partitionSetsUnpartitioned.count(_.contains(id)) > bound)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index db1dac6160080..b99075c301000 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -82,29 +82,6 @@ class EdgePartitionSuite extends FunSuite { assert(edgePartition.groupEdges(_ + _).iterator.map(_.copy()).toList === groupedEdges) } - test("upgradeIterator") { - val edges = List((0, 1, 0), (1, 0, 0)) - val verts = List((0L, 1), (1L, 2)) - val part = makeEdgePartition(edges).updateVertices(verts.iterator) - assert(part.upgradeIterator(part.iterator).map(_.toTuple).toList === - part.tripletIterator().toList.map(_.toTuple)) - } - - test("indexIterator") { - val edgesFrom0 = List(Edge(0, 1, 0)) - val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0)) - val sortedEdges = edgesFrom0 ++ edgesFrom1 - val builder = new EdgePartitionBuilder[Int, Nothing] - for (e <- Random.shuffle(sortedEdges)) { - builder.add(e.srcId, e.dstId, e.attr) - } - - val edgePartition = builder.toEdgePartition - assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges) - assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0) - assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1) - } - test("innerJoin") { val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0)) @@ -135,11 +112,13 @@ class EdgePartitionSuite extends FunSuite { for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a)) - assert(aSer.srcIds.toList === a.srcIds.toList) - assert(aSer.dstIds.toList === a.dstIds.toList) + assert(aSer.localSrcIds.toList === a.localSrcIds.toList) + assert(aSer.localDstIds.toList === a.localDstIds.toList) assert(aSer.data.toList === a.data.toList) assert(aSer.index != null) - assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet) + assert(aSer.global2local != null) + assert(aSer.local2global.toList === a.local2global.toList) + assert(aSer.vertexAttrs.toList === a.vertexAttrs.toList) } } } From b567be2825ea22f2e61fbd9caa34940f5bc404df Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 4 Nov 2014 01:56:48 -0800 Subject: [PATCH 2/7] iter.foreach -> while loop --- .../scala/org/apache/spark/graphx/impl/EdgePartition.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git 
a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 52661aa5f1d3c..8079de96796cc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -65,7 +65,7 @@ class EdgePartition[ /** Return a new `EdgePartition` with the specified active set, provided as an iterator. */ def withActiveSet(iter: Iterator[VertexId]): EdgePartition[ED, VD] = { val activeSet = new VertexSet - iter.foreach(activeSet.add(_)) + while (iter.hasNext) { activeSet.add(iter.next()) } new EdgePartition( localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, Some(activeSet)) @@ -81,7 +81,8 @@ class EdgePartition[ def updateVertices(iter: Iterator[(VertexId, VD)]): EdgePartition[ED, VD] = { val newVertexAttrs = new Array[VD](vertexAttrs.length) System.arraycopy(vertexAttrs, 0, newVertexAttrs, 0, vertexAttrs.length) - iter.foreach { kv => + while (iter.hasNext) { + val kv = iter.next() newVertexAttrs(global2local(kv._1)) = kv._2 } new EdgePartition( From c85076de62b4c3344c443d4e85fce8fc47274aac Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 4 Nov 2014 01:58:00 -0800 Subject: [PATCH 3/7] Readability improvements --- .../spark/graphx/impl/EdgePartition.scala | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 8079de96796cc..363c0fddcc1f7 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -98,9 +98,9 @@ class EdgePartition[ activeSet) } - def srcIds(i: Int): VertexId = local2global(localSrcIds(i)) + private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos)) - def dstIds(i: Int): VertexId = local2global(localDstIds(i)) + private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos)) /** Look up vid in activeSet, throwing an exception if it is None. 
*/ def isActive(vid: VertexId): Boolean = { @@ -231,23 +231,34 @@ class EdgePartition[ global2local, local2global, vertexAttrs) var currSrcId: VertexId = null.asInstanceOf[VertexId] var currDstId: VertexId = null.asInstanceOf[VertexId] + var currLocalSrcId = -1 + var currLocalDstId = -1 var currAttr: ED = null.asInstanceOf[ED] + // Iterate through the edges, accumulating runs of identical edges using the curr* variables and + // releasing them to the builder when we see the beginning of the next run var i = 0 while (i < size) { if (i > 0 && currSrcId == srcIds(i) && currDstId == dstIds(i)) { + // This edge should be accumulated into the existing run currAttr = merge(currAttr, data(i)) } else { + // This edge starts a new run of edges if (i > 0) { - builder.add(currSrcId, currDstId, localSrcIds(i - 1), localDstIds(i - 1), currAttr) + // First release the existing run to the builder + builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr) } + // Then start accumulating for a new run currSrcId = srcIds(i) currDstId = dstIds(i) + currLocalSrcId = localSrcIds(i) + currLocalDstId = localDstIds(i) currAttr = data(i) } i += 1 } + // Finally, release the last accumulated run if (size > 0) { - builder.add(currSrcId, currDstId, localSrcIds(i - 1), localDstIds(i - 1), currAttr) + builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr) } builder.toEdgePartition.withActiveSet(activeSet) } From e0f8ecc7b678de2b011650ed96b974369730947e Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 4 Nov 2014 01:58:23 -0800 Subject: [PATCH 4/7] Take activeSet in ExistingEdgePartitionBuilder Also rename VertexPreservingEdgePartitionBuilder to ExistingEdgePartitionBuilder to better reflect its usage. --- .../spark/graphx/impl/EdgePartition.scala | 30 ++++++++----------- .../graphx/impl/EdgePartitionBuilder.scala | 13 ++++---- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 363c0fddcc1f7..a0ab8a1becb21 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -71,12 +71,6 @@ class EdgePartition[ Some(activeSet)) } - /** Return a new `EdgePartition` with the specified active set. */ - def withActiveSet(activeSet: Option[VertexSet]): EdgePartition[ED, VD] = { - new EdgePartition( - localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet) - } - /** Return a new `EdgePartition` with updates to vertex attributes specified in `iter`. */ def updateVertices(iter: Iterator[(VertexId, VD)]): EdgePartition[ED, VD] = { val newVertexAttrs = new Array[VD](vertexAttrs.length) @@ -116,8 +110,8 @@ class EdgePartition[ * @return a new edge partition with all edges reversed. 
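   *
   * For example, a partition holding edges (1, 2) and (2, 3) reverses to one holding (2, 1) and
   * (3, 2); edge attributes, the vertex set, and the active set are carried over unchanged.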
*/ def reverse: EdgePartition[ED, VD] = { - val builder = new VertexPreservingEdgePartitionBuilder( - global2local, local2global, vertexAttrs, size)(classTag[ED], classTag[VD]) + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet, size) var i = 0 while (i < size) { val localSrcId = localSrcIds(i) @@ -128,7 +122,7 @@ class EdgePartition[ builder.add(dstId, srcId, localDstId, localSrcId, attr) i += 1 } - builder.toEdgePartition.withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -189,8 +183,8 @@ class EdgePartition[ def filter( epred: EdgeTriplet[VD, ED] => Boolean, vpred: (VertexId, VD) => Boolean): EdgePartition[ED, VD] = { - val builder = new VertexPreservingEdgePartitionBuilder[ED, VD]( - global2local, local2global, vertexAttrs) + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet) var i = 0 while (i < size) { // The user sees the EdgeTriplet, so we can't reuse it and must create one per edge. @@ -207,7 +201,7 @@ class EdgePartition[ } i += 1 } - builder.toEdgePartition.withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -227,8 +221,8 @@ class EdgePartition[ * @return a new edge partition without duplicate edges */ def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED, VD] = { - val builder = new VertexPreservingEdgePartitionBuilder[ED, VD]( - global2local, local2global, vertexAttrs) + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet) var currSrcId: VertexId = null.asInstanceOf[VertexId] var currDstId: VertexId = null.asInstanceOf[VertexId] var currLocalSrcId = -1 @@ -260,7 +254,7 @@ class EdgePartition[ if (size > 0) { builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr) } - builder.toEdgePartition.withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -276,8 +270,8 @@ class EdgePartition[ def innerJoin[ED2: ClassTag, ED3: ClassTag] (other: EdgePartition[ED2, _]) (f: (VertexId, VertexId, ED, ED2) => ED3): EdgePartition[ED3, VD] = { - val builder = new VertexPreservingEdgePartitionBuilder[ED3, VD]( - global2local, local2global, vertexAttrs) + val builder = new ExistingEdgePartitionBuilder[ED3, VD]( + global2local, local2global, vertexAttrs, activeSet) var i = 0 var j = 0 // For i = index of each edge in `this`... @@ -296,7 +290,7 @@ class EdgePartition[ } i += 1 } - builder.toEdgePartition.withActiveSet(activeSet) + builder.toEdgePartition } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 675247d1686a9..95a9dca3d16e7 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -77,14 +77,15 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla /** * Constructs an EdgePartition from an existing EdgePartition with the same vertex set. This enables - * reuse of the local vertex ids. + * reuse of the local vertex ids. Intended for internal use in EdgePartition only. 
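+ * A sketch of the intended pattern, with every argument taken from an existing partition:
+ * {{{
+ * val builder = new ExistingEdgePartitionBuilder[ED, VD](
+ *   global2local, local2global, vertexAttrs, activeSet)
+ * builder.add(srcId, dstId, localSrcId, localDstId, attr)
+ * val rebuilt = builder.toEdgePartition
+ * }}}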
*/ -private[graphx] -class VertexPreservingEdgePartitionBuilder[ +private[impl] +class ExistingEdgePartitionBuilder[ @specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], local2global: Array[VertexId], vertexAttrs: Array[VD], + activeSet: Option[VertexSet], size: Int = 64) { var edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size) @@ -119,14 +120,14 @@ class VertexPreservingEdgePartitionBuilder[ } new EdgePartition( - localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs) + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet) } } -private[graphx] case class EdgeWithLocalIds[@specialized ED]( +private[impl] case class EdgeWithLocalIds[@specialized ED]( srcId: VertexId, dstId: VertexId, localSrcId: Int, localDstId: Int, attr: ED) -private[graphx] object EdgeWithLocalIds { +private[impl] object EdgeWithLocalIds { implicit def lexicographicOrdering[ED] = new Ordering[EdgeWithLocalIds[ED]] { override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = { if (a.srcId == b.srcId) { From 194a2df94768be9c08ed50654170bad937bd115a Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 4 Nov 2014 02:03:34 -0800 Subject: [PATCH 5/7] Test triplet iterator in EdgePartition serialization test --- .../apache/spark/graphx/impl/EdgePartitionSuite.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index b99075c301000..c7a59990ce8e7 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -103,7 +103,7 @@ class EdgePartitionSuite extends FunSuite { } test("serialization") { - val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) + val aList = List((0, 1, 1), (1, 0, 2), (1, 2, 3), (5, 4, 4), (5, 5, 5)) val a: EdgePartition[Int, Int] = makeEdgePartition(aList) val javaSer = new JavaSerializer(new SparkConf()) val conf = new SparkConf() @@ -112,13 +112,8 @@ class EdgePartitionSuite extends FunSuite { for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a)) - assert(aSer.localSrcIds.toList === a.localSrcIds.toList) - assert(aSer.localDstIds.toList === a.localDstIds.toList) - assert(aSer.data.toList === a.data.toList) + assert(aSer.tripletIterator().toList === a.tripletIterator().toList) assert(aSer.index != null) - assert(aSer.global2local != null) - assert(aSer.local2global.toList === a.local2global.toList) - assert(aSer.vertexAttrs.toList === a.vertexAttrs.toList) } } } From 1e80aca308463b0ec7dbeee58c7d1935ebb59e77 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Sat, 1 Nov 2014 00:01:21 -0700 Subject: [PATCH 6/7] Add aggregateMessages, which supersedes mapReduceTriplets aggregateMessages enables neighborhood computation similarly to mapReduceTriplets, but it introduces two API improvements: 1. Messages are sent using an imperative interface based on EdgeContext rather than by returning an iterator of messages. This is more efficient, providing a 20.2% speedup on PageRank over apache/spark#3054 (uk-2007-05 graph, 10 iterations, 16 r3.2xlarge machines, sped up from 403 s to 322 s). 2. 
Rather than attempting bytecode inspection, the required triplet fields must be explicitly specified by the user by passing a TripletFields object. This fixes SPARK-3936. Subsumes apache/spark#2815. --- .../org/apache/spark/graphx/EdgeContext.scala | 51 ++++++++ .../scala/org/apache/spark/graphx/Graph.scala | 68 +++++++++- .../org/apache/spark/graphx/GraphOps.scala | 85 +++++++------ .../apache/spark/graphx/TripletFields.scala | 59 +++++++++ .../spark/graphx/impl/EdgePartition.scala | 118 ++++++++++-------- .../apache/spark/graphx/impl/GraphImpl.scala | 67 ++++++---- .../apache/spark/graphx/lib/PageRank.scala | 6 +- .../apache/spark/graphx/lib/SVDPlusPlus.scala | 46 +++---- .../spark/graphx/lib/TriangleCount.scala | 19 +-- .../org/apache/spark/graphx/GraphSuite.scala | 15 +++ 10 files changed, 374 insertions(+), 160 deletions(-) create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala new file mode 100644 index 0000000000000..ad85376cec8ac --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +/** + * Represents an edge along with its neighboring vertices and allows sending messages along the + * edge. Used in [[Graph#aggregateMessages]]. + */ +trait EdgeContext[VD, ED, A] { + /** The vertex id of the edge's source vertex. */ + def srcId: VertexId + /** The vertex id of the edge's destination vertex. */ + def dstId: VertexId + /** The vertex attribute of the edge's source vertex. */ + def srcAttr: VD + /** The vertex attribute of the edge's destination vertex. */ + def dstAttr: VD + /** The attribute associated with the edge. */ + def attr: ED + + /** Sends a message to the source vertex. */ + def sendToSrc(msg: A): Unit + /** Sends a message to the destination vertex. */ + def sendToDst(msg: A): Unit + + /** Converts the edge and vertex properties into an [[EdgeTriplet]] for convenience. 
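+   * The five fields are copied into a fresh object, so the returned triplet remains valid even
+   * if this context is later reused for another edge.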
*/ + def toEdgeTriplet: EdgeTriplet[VD, ED] = { + val et = new EdgeTriplet[VD, ED] + et.srcId = srcId + et.srcAttr = srcAttr + et.dstId = dstId + et.dstAttr = dstAttr + et.attr = attr + et + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index fa4b891754c40..c0c7ca19d3b76 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -195,6 +195,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * the underlying index structures can be reused. * * @param map the function from an edge object to a new edge value. + * @param tripletFields which fields should be included in the edge triplet passed to the map + * function. If not all fields are needed, specifying this can improve performance. * * @tparam ED2 the new edge data type * @@ -207,8 +209,10 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * }}} * */ - def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = { - mapTriplets((pid, iter) => iter.map(map)) + def mapTriplets[ED2: ClassTag]( + map: EdgeTriplet[VD, ED] => ED2, + tripletFields: TripletFields = TripletFields.All): Graph[VD, ED2] = { + mapTriplets((pid, iter) => iter.map(map), tripletFields) } /** @@ -223,12 +227,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * the underlying index structures can be reused. * * @param map the iterator transform + * @param tripletFields which fields should be included in the edge triplet passed to the map + * function. If not all fields are needed, specifying this can improve performance. * * @tparam ED2 the new edge data type * */ - def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]) - : Graph[VD, ED2] + def mapTriplets[ED2: ClassTag]( + map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2], + tripletFields: TripletFields): Graph[VD, ED2] /** * Reverses all edges in the graph. If this graph contains an edge from a to b then the returned @@ -287,6 +294,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of * the map phase destined to each vertex. * + * This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead. + * * @tparam A the type of "message" to be sent to each vertex * * @param mapFunc the user defined map function which returns 0 or @@ -319,6 +328,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * predicate or implement PageRank. * */ + @deprecated("use aggregateMessages", "1.2.0") def mapReduceTriplets[A: ClassTag]( mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], reduceFunc: (A, A) => A, @@ -326,8 +336,54 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab : VertexRDD[A] /** - * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The - * input table should contain at most one entry for each vertex. If no entry in `other` is + * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied + * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be + * sent to either vertex in the edge. 
The `mergeMsg` function is then used to combine all messages
+   * destined to the same vertex.
+   *
+   * @tparam A the type of message to be sent to each vertex
+   *
+   * @param sendMsg runs on each edge, sending messages to neighboring vertices using the
+   *   [[EdgeContext]].
+   * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
+   *   combiner should be commutative and associative.
+   * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
+   *   `sendMsg` function. If not all fields are needed, specifying this can improve performance.
+   * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
+   *   desired. This is done by specifying a set of "active" vertices and an edge direction. The
+   *   `sendMsg` function will then run on only edges connected to active vertices by edges in the
+   *   specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
+   *   destination in the active set. If the direction is `Out`, `sendMsg` will only be run on
+   *   edges originating from vertices in the active set. If the direction is `Either`, `sendMsg`
+   *   will be run on edges with *either* vertex in the active set. If the direction is `Both`,
+   *   `sendMsg` will be run on edges with *both* vertices in the active set. The active set must
+   *   have the same index as the graph's vertices.
+   *
+   * @example We can use this function to compute the in-degree of each
+   * vertex
+   * {{{
+   * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph")
+   * val inDeg: RDD[(VertexId, Int)] =
+   *   rawGraph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _)
+   * }}}
+   *
+   * @note By expressing computation at the edge level we achieve
+   * maximum parallelism. This is one of the core functions in the
+   * Graph API in that it enables neighborhood level computation. For
+   * example this function can be used to count neighbors satisfying a
+   * predicate or implement PageRank.
+   *
+   */
+  def aggregateMessages[A: ClassTag](
+      sendMsg: EdgeContext[VD, ED, A] => Unit,
+      mergeMsg: (A, A) => A,
+      tripletFields: TripletFields = TripletFields.All,
+      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
+    : VertexRDD[A]
+
+  /**
+   * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`.
+   * The input table should contain at most one entry for each vertex. If no entry in `other` is
    * provided for a particular vertex in the graph, the map function receives `None`.
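+   *
+   * A hedged usage sketch (assuming a precomputed `outDegrees: VertexRDD[Int]`):
+   * {{{
+   * val degreeGraph = graph.outerJoinVertices(outDegrees) { (vid, attr, degOpt) =>
+   *   degOpt.getOrElse(0)  // vertices absent from outDegrees receive None
+   * }
+   * }}}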
* * @tparam U the type of entry in the table of updates diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index d0dd45dba618e..d5150382d599b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali */ private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = { if (edgeDirection == EdgeDirection.In) { - graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _) + graph.aggregateMessages(_.sendToDst(1), _ + _, TripletFields.None) } else if (edgeDirection == EdgeDirection.Out) { - graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _) + graph.aggregateMessages(_.sendToSrc(1), _ + _, TripletFields.None) } else { // EdgeDirection.Either - graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _) + graph.aggregateMessages(ctx => { ctx.sendToSrc(1); ctx.sendToDst(1) }, _ + _, + TripletFields.None) } } @@ -88,18 +89,17 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = { val nbrs = if (edgeDirection == EdgeDirection.Either) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))), - reduceFunc = _ ++ _ - ) + graph.aggregateMessages[Array[VertexId]]( + ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) }, + _ ++ _, TripletFields.None) } else if (edgeDirection == EdgeDirection.Out) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.srcId, Array(et.dstId))), - reduceFunc = _ ++ _) + graph.aggregateMessages[Array[VertexId]]( + ctx => ctx.sendToSrc(Array(ctx.dstId)), + _ ++ _, TripletFields.None) } else if (edgeDirection == EdgeDirection.In) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.dstId, Array(et.srcId))), - reduceFunc = _ ++ _) + graph.aggregateMessages[Array[VertexId]]( + ctx => ctx.sendToDst(Array(ctx.srcId)), + _ ++ _, TripletFields.None) } else { throw new SparkException("It doesn't make sense to collect neighbor ids without a " + "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)") @@ -122,22 +122,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @return the vertex set of neighboring vertex attributes for each vertex */ def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = { - val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]]( - edge => { - val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr))) - val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr))) - edgeDirection match { - case EdgeDirection.Either => Iterator(msgToSrc, msgToDst) - case EdgeDirection.In => Iterator(msgToDst) - case EdgeDirection.Out => Iterator(msgToSrc) - case EdgeDirection.Both => - throw new SparkException("collectNeighbors does not support EdgeDirection.Both. 
Use" + - "EdgeDirection.Either instead.") - } - }, - (a, b) => a ++ b) - - graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) => + val nbrs = edgeDirection match { + case EdgeDirection.Either => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => { + ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) + ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) + }, + (a, b) => a ++ b, TripletFields.SrcDstOnly) + case EdgeDirection.In => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), + (a, b) => a ++ b, TripletFields.SrcOnly) + case EdgeDirection.Out => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), + (a, b) => a ++ b, TripletFields.DstOnly) + case EdgeDirection.Both => + throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" + + "EdgeDirection.Either instead.") + } + graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) => nbrsOpt.getOrElse(Array.empty[(VertexId, VD)]) } } // end of collectNeighbor @@ -160,18 +165,20 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = { edgeDirection match { case EdgeDirection.Either => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))), - (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => { + ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))) + ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))) + }, + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.In => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))), + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.Out => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))), + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.Both => throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" + "EdgeDirection.Either instead.") diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala new file mode 100644 index 0000000000000..e92e2763a0c06 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +/** + * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the + * system to populate only those fields for efficiency. + */ +class TripletFields private ( + val useSrc: Boolean, + val useDst: Boolean, + val useEdge: Boolean) + extends Serializable { + private def this() = this(true, true, true) +} + +/** + * Exposes all possible [[TripletFields]] objects. + */ +object TripletFields { + final val None = new TripletFields(useSrc = false, useDst = false, useEdge = false) + final val EdgeOnly = new TripletFields(useSrc = false, useDst = false, useEdge = true) + final val SrcOnly = new TripletFields(useSrc = true, useDst = false, useEdge = false) + final val DstOnly = new TripletFields(useSrc = false, useDst = true, useEdge = false) + final val SrcDstOnly = new TripletFields(useSrc = true, useDst = true, useEdge = false) + final val SrcAndEdge = new TripletFields(useSrc = true, useDst = false, useEdge = true) + final val Src = SrcAndEdge + final val DstAndEdge = new TripletFields(useSrc = false, useDst = true, useEdge = true) + final val Dst = DstAndEdge + final val All = new TripletFields(useSrc = true, useDst = true, useEdge = true) + + /** Returns the appropriate [[TripletFields]] object. */ + private[graphx] def apply(useSrc: Boolean, useDst: Boolean, useEdge: Boolean) = + (useSrc, useDst, useEdge) match { + case (false, false, false) => TripletFields.None + case (false, false, true) => EdgeOnly + case (true, false, false) => SrcOnly + case (false, true, false) => DstOnly + case (true, true, false) => SrcDstOnly + case (true, false, true) => SrcAndEdge + case (false, true, true) => DstAndEdge + case (true, true, true) => All + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index a0ab8a1becb21..86ee3923151b4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -340,24 +340,23 @@ class EdgePartition[ * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning * all edges sequentially and filtering them with `idPred`. 
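The `TripletFields` constants introduced above are what callers pass to request only the attributes their function actually reads. As a minimal sketch of the intended usage, assuming a graph whose vertex attributes are `String` names produced by a hypothetical `someLoadFunction()`:

{{{
// Send each source vertex's name to its out-neighbors and concatenate the
// names on arrival. Only ctx.srcAttr is read, so TripletFields.SrcOnly lets
// the system skip replicating destination attributes to edge partitions.
val graph: Graph[String, Int] = someLoadFunction()
val namesOfInNeighbors: VertexRDD[String] = graph.aggregateMessages[String](
  ctx => ctx.sendToDst(ctx.srcAttr),
  (a, b) => a + ", " + b,
  TripletFields.SrcOnly)
}}}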
* - * @param mapFunc the edge map function which generates messages to neighboring vertices - * @param reduceFunc the combiner applied to messages destined to the same vertex - * @param mapUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute - * @param mapUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute + * @param sendMsg generates messages to neighboring vertices of an edge + * @param mergeMsg the combiner applied to messages destined to the same vertex + * @param tripletFields which triplet fields `sendMsg` uses; only those fields of the + * [[EdgeContext]] will be populated * @param idPred a predicate to filter edges based on their source and destination vertex ids * * @return iterator aggregated messages keyed by the receiving vertex id */ - def mapReduceTriplets[A: ClassTag]( - mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], - reduceFunc: (A, A) => A, - mapUsesSrcAttr: Boolean, - mapUsesDstAttr: Boolean, + def aggregateMessages[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, idPred: (VertexId, VertexId) => Boolean): Iterator[(VertexId, A)] = { val aggregates = new Array[A](vertexAttrs.length) val bitset = new BitSet(vertexAttrs.length) - var edge = new EdgeTriplet[VD, ED] + var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset) var i = 0 while (i < size) { val localSrcId = localSrcIds(i) @@ -365,23 +364,14 @@ val localDstId = localDstIds(i) val dstId = local2global(localDstId) if (idPred(srcId, dstId)) { - edge.srcId = srcId - edge.dstId = dstId - edge.attr = data(i) - if (mapUsesSrcAttr) { edge.srcAttr = vertexAttrs(localSrcId) } - if (mapUsesDstAttr) { edge.dstAttr = vertexAttrs(localDstId) } - - mapFunc(edge).foreach { kv => - val globalId = kv._1 - val msg = kv._2 - val localId = if (globalId == srcId) localSrcId else localDstId - if (bitset.get(localId)) { - aggregates(localId) = reduceFunc(aggregates(localId), msg) - } else { - aggregates(localId) = msg - bitset.set(localId) - } - } + ctx.localSrcId = localSrcId + ctx.localDstId = localDstId + ctx.srcId = srcId + ctx.dstId = dstId + ctx.attr = data(i) + if (tripletFields.useSrc) { ctx.srcAttr = vertexAttrs(localSrcId) } + if (tripletFields.useDst) { ctx.dstAttr = vertexAttrs(localDstId) } + sendMsg(ctx) } i += 1 } @@ -394,53 +384,41 @@ class EdgePartition[ * Send messages along edges and aggregate them at the receiving vertices. Implemented by * filtering the source vertex index with `srcIdPred`, then scanning edge clusters and filtering * with `dstIdPred`. Both `srcIdPred` and `dstIdPred` must match for an edge to run.
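For orientation, the `idPred` argument above and the `srcIdPred`/`dstIdPred` pair below are supplied by `GraphImpl` from the active set; the two call shapes, taken from the `GraphImpl` hunk later in this patch with `edgePartition` assumed in scope, look like:

{{{
// Edge scan: one predicate over both endpoint ids, evaluated per edge.
edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields,
  (srcId, dstId) => edgePartition.isActive(srcId) && edgePartition.isActive(dstId))

// Index scan: a source predicate filters whole clusters via the clustered
// index; a destination predicate filters edges within each cluster.
edgePartition.aggregateMessagesWithIndex(sendMsg, mergeMsg, tripletFields,
  srcId => edgePartition.isActive(srcId),
  dstId => edgePartition.isActive(dstId))
}}}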
* - * @param mapFunc the edge map function which generates messages to neighboring vertices - * @param reduceFunc the combiner applied to messages destined to the same vertex - * @param mapUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute - * @param mapUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute + * @param sendMsg generates messages to neighboring vertices of an edge + * @param mergeMsg the combiner applied to messages destined to the same vertex * @param srcIdPred a predicate to filter edges based on their source vertex id * @param dstIdPred a predicate to filter edges based on their destination vertex id * * @return iterator aggregated messages keyed by the receiving vertex id */ - def mapReduceTripletsWithIndex[A: ClassTag]( - mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], - reduceFunc: (A, A) => A, - mapUsesSrcAttr: Boolean, - mapUsesDstAttr: Boolean, + def aggregateMessagesWithIndex[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, srcIdPred: VertexId => Boolean, dstIdPred: VertexId => Boolean): Iterator[(VertexId, A)] = { val aggregates = new Array[A](vertexAttrs.length) val bitset = new BitSet(vertexAttrs.length) - var edge = new EdgeTriplet[VD, ED] + var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset) index.iterator.foreach { cluster => val clusterSrcId = cluster._1 val clusterPos = cluster._2 val clusterLocalSrcId = localSrcIds(clusterPos) if (srcIdPred(clusterSrcId)) { var pos = clusterPos - edge.srcId = clusterSrcId - if (mapUsesSrcAttr) { edge.srcAttr = vertexAttrs(clusterLocalSrcId) } + ctx.srcId = clusterSrcId + ctx.localSrcId = clusterLocalSrcId + if (tripletFields.useSrc) { ctx.srcAttr = vertexAttrs(clusterLocalSrcId) } while (pos < size && localSrcIds(pos) == clusterLocalSrcId) { val localDstId = localDstIds(pos) val dstId = local2global(localDstId) if (dstIdPred(dstId)) { - edge.dstId = dstId - edge.attr = data(pos) - if (mapUsesDstAttr) { edge.dstAttr = vertexAttrs(localDstId) } - - mapFunc(edge).foreach { kv => - val globalId = kv._1 - val msg = kv._2 - val localId = if (globalId == clusterSrcId) clusterLocalSrcId else localDstId - if (bitset.get(localId)) { - aggregates(localId) = reduceFunc(aggregates(localId), msg) - } else { - aggregates(localId) = msg - bitset.set(localId) - } - } + ctx.dstId = dstId + ctx.localDstId = localDstId + ctx.attr = data(pos) + if (tripletFields.useDst) { ctx.dstAttr = vertexAttrs(localDstId) } + sendMsg(ctx) } pos += 1 } @@ -450,3 +428,35 @@ class EdgePartition[ bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) } } } + +private class AggregatingEdgeContext[VD, ED, A]( + mergeMsg: (A, A) => A, + aggregates: Array[A], + bitset: BitSet) + extends EdgeContext[VD, ED, A] { + + var srcId: VertexId = _ + var dstId: VertexId = _ + var srcAttr: VD = _ + var dstAttr: VD = _ + var attr: ED = _ + + var localSrcId: Int = _ + var localDstId: Int = _ + + override def sendToSrc(msg: A) { + send(localSrcId, msg) + } + override def sendToDst(msg: A) { + send(localDstId, msg) + } + + private def send(localId: Int, msg: A) { + if (bitset.get(localId)) { + aggregates(localId) = mergeMsg(aggregates(localId), msg) + } else { + aggregates(localId) = msg + bitset.set(localId) + } + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 1188e2ad91821..bcbb22b9100dc 
100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -126,13 +126,12 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } override def mapTriplets[ED2: ClassTag]( - f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = { + f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2], + tripletFields: TripletFields): Graph[VD, ED2] = { vertices.cache() - val mapUsesSrcAttr = accessesVertexAttr(f, "srcAttr") - val mapUsesDstAttr = accessesVertexAttr(f, "dstAttr") - replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr) + replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst) val newEdges = replicatedVertexView.edges.mapEdgePartitions { (pid, part) => - part.map(f(pid, part.tripletIterator(mapUsesSrcAttr, mapUsesDstAttr))) + part.map(f(pid, part.tripletIterator(tripletFields.useSrc, tripletFields.useDst))) } new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges)) } @@ -170,15 +169,38 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def mapReduceTriplets[A: ClassTag]( mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], reduceFunc: (A, A) => A, + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = { + + def sendMsg(ctx: EdgeContext[VD, ED, A]) { + mapFunc(ctx.toEdgeTriplet).foreach { kv => + val id = kv._1 + val msg = kv._2 + if (id == ctx.srcId) { + ctx.sendToSrc(msg) + } else { + assert(id == ctx.dstId) + ctx.sendToDst(msg) + } + } + } + + val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") + val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") + val tripletFields = TripletFields(mapUsesSrcAttr, mapUsesDstAttr, useEdge = true) + + aggregateMessages(sendMsg, reduceFunc, tripletFields, activeSetOpt) + } + + override def aggregateMessages[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = { vertices.cache() - // For each vertex, replicate its attribute only to partitions where it is // in the relevant position in an edge. - val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") - val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") - replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr) + replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst) val view = activeSetOpt match { case Some((activeSet, _)) => replicatedVertexView.withActiveSet(activeSet) @@ -195,46 +217,39 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( activeDirectionOpt match { case Some(EdgeDirection.Both) => if (activeFraction < 0.8) { - edgePartition.mapReduceTripletsWithIndex( - mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + edgePartition.aggregateMessagesWithIndex(sendMsg, mergeMsg, tripletFields, srcId => edgePartition.isActive(srcId), dstId => edgePartition.isActive(dstId)) } else { - edgePartition.mapReduceTriplets( - mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields, (srcId, dstId) => edgePartition.isActive(srcId) && edgePartition.isActive(dstId)) } case Some(EdgeDirection.Either) => // TODO: Because we only have a clustered index on the source vertex ID, we can't filter // the index here. Instead we have to scan all edges and then do the filter. 
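From a caller's perspective, the adapter above means existing `mapReduceTriplets` code keeps working unchanged, while new code states its sends imperatively. A sketch of the same message (an edge-attribute sum delivered to each destination) in both styles, assuming `graph: Graph[_, Int]`:

{{{
// Old style: allocates an Iterator of (target id, message) pairs per edge.
val oldSums = graph.mapReduceTriplets[Int](
  et => Iterator((et.dstId, et.attr)), _ + _)

// New style: no per-edge Iterator allocation, and the fields hint tells the
// system that neither vertex attribute needs to be shipped.
val newSums = graph.aggregateMessages[Int](
  ctx => ctx.sendToDst(ctx.attr), _ + _, TripletFields.EdgeOnly)
}}}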
- edgePartition.mapReduceTriplets( - mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields, (srcId, dstId) => edgePartition.isActive(srcId) || edgePartition.isActive(dstId)) case Some(EdgeDirection.Out) => if (activeFraction < 0.8) { - edgePartition.mapReduceTripletsWithIndex( - mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + edgePartition.aggregateMessagesWithIndex(sendMsg, mergeMsg, tripletFields, srcId => edgePartition.isActive(srcId), dstId => true) } else { - edgePartition.mapReduceTriplets( - mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields, (srcId, dstId) => edgePartition.isActive(srcId)) } case Some(EdgeDirection.In) => - edgePartition.mapReduceTriplets( - mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields, (srcId, dstId) => edgePartition.isActive(dstId)) case _ => // None - edgePartition.mapReduceTriplets( - mapFunc, reduceFunc, mapUsesSrcAttr, mapUsesDstAttr, + edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields, (srcId, dstId) => true) } - }).setName("GraphImpl.mapReduceTriplets - preAgg") + }).setName("GraphImpl.aggregateMessages - preAgg") // do the final reduction reusing the index map - vertices.aggregateUsingIndex(preAgg, reduceFunc) - } // end of mapReduceTriplets + vertices.aggregateUsingIndex(preAgg, mergeMsg) + } override def outerJoinVertices[U: ClassTag, VD2: ClassTag] (other: RDD[(VertexId, U)]) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 257e2f3a36115..e40ae0d615466 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -85,7 +85,7 @@ object PageRank extends Logging { // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree - .mapTriplets( e => 1.0 / e.srcAttr ) + .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.SrcOnly ) // Set the vertex attributes to the initial pagerank values .mapVertices( (id, attr) => resetProb ) @@ -96,8 +96,8 @@ object PageRank extends Logging { // Compute the outgoing rank contributions of each vertex, perform local preaggregation, and // do the final aggregation at the receiving vertices. Requires a shuffle for aggregation. - val rankUpdates = rankGraph.mapReduceTriplets[Double]( - e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _) + val rankUpdates = rankGraph.aggregateMessages[Double]( + ctx => ctx.sendToDst(ctx.srcAttr * ctx.attr), _ + _, TripletFields.SrcAndEdge) // Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices // that didn't receive a message. 
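The join that this comment refers to is unchanged by the patch; for context, the update step looks roughly like this (a sketch reconstructed from the surrounding PageRank code, not part of the diff):

{{{
// Fold the aggregated contributions back into the rank graph; vertices with
// no incoming message keep only the reset probability.
rankGraph = rankGraph.joinVertices(rankUpdates) {
  (id, oldRank, msgSum) => resetProb + (1.0 - resetProb) * msgSum
}.cache()
}}}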
Requires a shuffle for broadcasting updated ranks to the diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index ccd7de537b6e3..f58587e10a820 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -74,9 +74,9 @@ object SVDPlusPlus { var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache() // Calculate initial bias and norm - val t0 = g.mapReduceTriplets( - et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))), - (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2)) + val t0 = g.aggregateMessages[(Long, Double)]( + ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, + (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2)) g = g.outerJoinVertices(t0) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), @@ -84,15 +84,17 @@ object SVDPlusPlus { (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) } - def mapTrainF(conf: Conf, u: Double) - (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double]) - : Iterator[(VertexId, (DoubleMatrix, DoubleMatrix, Double))] = { - val (usr, itm) = (et.srcAttr, et.dstAttr) + def sendMsgTrainF(conf: Conf, u: Double) + (ctx: EdgeContext[ + (DoubleMatrix, DoubleMatrix, Double, Double), + Double, + (DoubleMatrix, DoubleMatrix, Double)]) { + val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) var pred = u + usr._3 + itm._3 + q.dot(usr._2) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) - val err = et.attr - pred + val err = ctx.attr - pred val updateP = q.mul(err) .subColumnVector(p.mul(conf.gamma7)) .mul(conf.gamma2) @@ -102,16 +104,16 @@ object SVDPlusPlus { val updateY = q.mul(err * usr._4) .subColumnVector(itm._2.mul(conf.gamma7)) .mul(conf.gamma2) - Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)), - (et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))) + ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)) + ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)) } for (i <- 0 until conf.maxIters) { // Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes g.cache() - val t1 = g.mapReduceTriplets( - et => Iterator((et.srcId, et.dstAttr._2)), - (g1: DoubleMatrix, g2: DoubleMatrix) => g1.addColumnVector(g2)) + val t1 = g.aggregateMessages[DoubleMatrix]( + ctx => ctx.sendToSrc(ctx.dstAttr._2), + (g1, g2) => g1.addColumnVector(g2)) g = g.outerJoinVertices(t1) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[DoubleMatrix]) => @@ -121,8 +123,8 @@ object SVDPlusPlus { // Phase 2, update p for user nodes and q, y for item nodes g.cache() - val t2 = g.mapReduceTriplets( - mapTrainF(conf, u), + val t2 = g.aggregateMessages( + sendMsgTrainF(conf, u), (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) g = g.outerJoinVertices(t2) { @@ -135,20 +137,18 @@ object SVDPlusPlus { } // calculate error on training set - def mapTestF(conf: Conf, u: Double) - (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double]) - : Iterator[(VertexId, Double)] = - { - val (usr, itm) = (et.srcAttr, et.dstAttr) + def sendMsgTestF(conf: Conf, u: Double) + (ctx: EdgeContext[(DoubleMatrix, 
DoubleMatrix, Double, Double), Double, Double]) { + val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) var pred = u + usr._3 + itm._3 + q.dot(usr._2) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) - val err = (et.attr - pred) * (et.attr - pred) - Iterator((et.dstId, err)) + val err = (ctx.attr - pred) * (ctx.attr - pred) + ctx.sendToDst(err) } g.cache() - val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2) + val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _) g = g.outerJoinVertices(t3) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index 7c396e6e66a28..daf162085e3e4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -61,26 +61,27 @@ object TriangleCount { (vid, _, optSet) => optSet.getOrElse(null) } // Edge function computes intersection of smaller vertex with larger vertex - def edgeFunc(et: EdgeTriplet[VertexSet, ED]): Iterator[(VertexId, Int)] = { - assert(et.srcAttr != null) - assert(et.dstAttr != null) - val (smallSet, largeSet) = if (et.srcAttr.size < et.dstAttr.size) { - (et.srcAttr, et.dstAttr) + def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) { + assert(ctx.srcAttr != null) + assert(ctx.dstAttr != null) + val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) { + (ctx.srcAttr, ctx.dstAttr) } else { - (et.dstAttr, et.srcAttr) + (ctx.dstAttr, ctx.srcAttr) } val iter = smallSet.iterator var counter: Int = 0 while (iter.hasNext) { val vid = iter.next() - if (vid != et.srcId && vid != et.dstId && largeSet.contains(vid)) { + if (vid != ctx.srcId && vid != ctx.dstId && largeSet.contains(vid)) { counter += 1 } } - Iterator((et.srcId, counter), (et.dstId, counter)) + ctx.sendToSrc(counter) + ctx.sendToDst(counter) } // compute the intersection along edges - val counters: VertexRDD[Int] = setGraph.mapReduceTriplets(edgeFunc, _ + _) + val counters: VertexRDD[Int] = setGraph.aggregateMessages(edgeFunc, _ + _) // Merge counters with the graph and divide by two since each triangle is counted twice g.outerJoinVertices(counters) { (vid, _, optCounter: Option[Int]) => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 697afef29029c..df773db6e4326 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -318,6 +318,21 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("aggregateMessages") { + withSpark { sc => + val n = 5 + val agg = starGraph(sc, n).aggregateMessages[String]( + ctx => { + if (ctx.dstAttr != null) { + throw new Exception( + "expected ctx.dstAttr to be null due to TripletFields, but it was " + ctx.dstAttr) + } + ctx.sendToDst(ctx.srcAttr) + }, _ + _, TripletFields.SrcOnly) + assert(agg.collect().toSet === (1 to n).map(x => (x: VertexId, "v")).toSet) + } + } + test("outerJoinVertices") { withSpark { sc => val n = 5 From f5b65d0695594781324c3ddb9c41ec53a476ac95 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 11 Nov 2014 20:33:40 -0800 Subject: [PATCH 7/7] Address @rxin comments on apache/spark#3054 and 
apache/spark#3100 --- .../org/apache/spark/graphx/EdgeContext.scala | 2 +- .../scala/org/apache/spark/graphx/Graph.scala | 93 +++++++-- .../apache/spark/graphx/TripletFields.java | 51 +++++ .../apache/spark/graphx/TripletFields.scala | 59 ------ .../spark/graphx/impl/EdgePartition.scala | 194 ++++++++++++------ .../graphx/impl/EdgePartitionBuilder.scala | 7 +- .../graphx/impl/EdgeTripletIterator.scala | 57 ----- .../apache/spark/graphx/impl/GraphImpl.scala | 40 ++-- .../graphx/impl/EdgePartitionSuite.scala | 11 +- .../impl/EdgeTripletIteratorSuite.scala | 37 ---- 10 files changed, 295 insertions(+), 256 deletions(-) create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java delete mode 100644 graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala delete mode 100644 graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala delete mode 100644 graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala index ad85376cec8ac..f70715fca6eea 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -21,7 +21,7 @@ package org.apache.spark.graphx * Represents an edge along with its neighboring vertices and allows sending messages along the * edge. Used in [[Graph#aggregateMessages]]. */ -trait EdgeContext[VD, ED, A] { +abstract class EdgeContext[VD, ED, A] { /** The vertex id of the edge's source vertex. */ def srcId: VertexId /** The vertex id of the edge's destination vertex. */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index c0c7ca19d3b76..e0ba9403ba75b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -185,6 +185,33 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab def mapEdges[ED2: ClassTag](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]) : Graph[VD, ED2] + /** + * Transforms each edge attribute using the map function, passing it the adjacent vertex + * attributes as well. If adjacent vertex values are not required, + * consider using `mapEdges` instead. + * + * @note This does not change the structure of the + * graph or modify the values of this graph. As a consequence + * the underlying index structures can be reused. + * + * @param map the function from an edge object to a new edge value. + * + * @tparam ED2 the new edge data type + * + * @example This function might be used to initialize edge + * attributes based on the attributes associated with each vertex. + * {{{ + * val rawGraph: Graph[Int, Int] = someLoadFunction() + * val graph = rawGraph.mapTriplets[Int]( edge => + * edge.srcAttr - edge.dstAttr) + * }}} + * + */ + def mapTriplets[ED2: ClassTag]( + map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = { + mapTriplets((pid, iter) => iter.map(map), TripletFields.All) + } + /** * Transforms each edge attribute using the map function, passing it the adjacent vertex * attributes as well.
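The practical difference between the two `mapTriplets` overloads is whether the caller supplies a `TripletFields` hint. A sketch, reusing the weight-initialization expression from the PageRank change earlier in this patch series and assuming `graph` carries numeric (`Double`) out-degree vertex attributes:

{{{
// Default overload: conservatively populates both endpoint attributes
// (TripletFields.All).
val g1 = graph.mapTriplets(e => 1.0 / e.srcAttr)

// Hinted overload: only source attributes are replicated to edge partitions.
// TripletFields.Src also covers the edge attribute; SrcOnly would be the
// tightest hint for this particular expression.
val g2 = graph.mapTriplets(e => 1.0 / e.srcAttr, TripletFields.Src)
}}}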
If adjacent vertex values are not required, @@ -211,7 +238,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab */ def mapTriplets[ED2: ClassTag]( map: EdgeTriplet[VD, ED] => ED2, - tripletFields: TripletFields = TripletFields.All): Graph[VD, ED2] = { + tripletFields: TripletFields): Graph[VD, ED2] = { mapTriplets((pid, iter) => iter.map(map), tripletFields) } @@ -305,13 +332,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * be commutative and associative and is used to combine the output * of the map phase * - * @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to - * consider when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on - * edges with destination in the active set. If the direction is `Out`, - * `mapFunc` will only be run on edges originating from vertices in the active set. If the - * direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set - * . If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the - * active set. The active set must have the same index as the graph's vertices. + * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if + * desired. This is done by specifying a set of "active" vertices and an edge direction. The + * `sendMsg` function will then run only on edges connected to active vertices by edges in the + * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with + * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges + * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be + * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` + * will be run on edges with *both* vertices in the active set. The active set must have the + * same index as the graph's vertices. * * @example We can use this function to compute the in-degree of each * vertex @@ -349,15 +378,6 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * combiner should be commutative and associative. * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the * `sendMsg` function. If not all fields are needed, specifying this can improve performance. - * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if - * desired. This is done by specifying a set of "active" vertices and an edge direction. The - * `sendMsg` function will then run on only edges connected to active vertices by edges in the - * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with - * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges - * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be - * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` - * will be run on edges with *both* vertices in the active set. The active set must have the - * same index as the graph's vertices. 
* * @example We can use this function to compute the in-degree of each * vertex @@ -377,8 +397,43 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab def aggregateMessages[A: ClassTag]( sendMsg: EdgeContext[VD, ED, A] => Unit, mergeMsg: (A, A) => A, - tripletFields: TripletFields = TripletFields.All, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None) + tripletFields: TripletFields = TripletFields.All) + : VertexRDD[A] = { + aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None) + } + + /** + * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied + * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be + * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages + * destined to the same vertex. + * + * This variant can take an active set to restrict the computation and is intended for internal + * use only. + * + * @tparam A the type of message to be sent to each vertex + * + * @param sendMsg runs on each edge, sending messages to neighboring vertices using the + * [[EdgeContext]]. + * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This + * combiner should be commutative and associative. + * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the + * `sendMsg` function. If not all fields are needed, specifying this can improve performance. + * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if + * desired. This is done by specifying a set of "active" vertices and an edge direction. The + * `sendMsg` function will then run on only edges connected to active vertices by edges in the + * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with + * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges + * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be + * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` + * will be run on edges with *both* vertices in the active set. The active set must have the + * same index as the graph's vertices. + */ + private[graphx] def aggregateMessagesWithActiveSet[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]) : VertexRDD[A] /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java new file mode 100644 index 0000000000000..34df4b7ee7a06 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx; + +import java.io.Serializable; + +/** + * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the + * system to populate only those fields for efficiency. + */ +public class TripletFields implements Serializable { + public final boolean useSrc; + public final boolean useDst; + public final boolean useEdge; + + public TripletFields() { + this(true, true, true); + } + + public TripletFields(boolean useSrc, boolean useDst, boolean useEdge) { + this.useSrc = useSrc; + this.useDst = useDst; + this.useEdge = useEdge; + } + + public static final TripletFields None = new TripletFields(false, false, false); + public static final TripletFields EdgeOnly = new TripletFields(false, false, true); + public static final TripletFields SrcOnly = new TripletFields(true, false, false); + public static final TripletFields DstOnly = new TripletFields(false, true, false); + public static final TripletFields SrcDstOnly = new TripletFields(true, true, false); + public static final TripletFields SrcAndEdge = new TripletFields(true, false, true); + public static final TripletFields Src = SrcAndEdge; + public static final TripletFields DstAndEdge = new TripletFields(false, true, true); + public static final TripletFields Dst = DstAndEdge; + public static final TripletFields All = new TripletFields(true, true, true); +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala deleted file mode 100644 index e92e2763a0c06..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx - -/** - * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the - * system to populate only those fields for efficiency. - */ -class TripletFields private ( - val useSrc: Boolean, - val useDst: Boolean, - val useEdge: Boolean) - extends Serializable { - private def this() = this(true, true, true) -} - -/** - * Exposes all possible [[TripletFields]] objects. 
- */ -object TripletFields { - final val None = new TripletFields(useSrc = false, useDst = false, useEdge = false) - final val EdgeOnly = new TripletFields(useSrc = false, useDst = false, useEdge = true) - final val SrcOnly = new TripletFields(useSrc = true, useDst = false, useEdge = false) - final val DstOnly = new TripletFields(useSrc = false, useDst = true, useEdge = false) - final val SrcDstOnly = new TripletFields(useSrc = true, useDst = true, useEdge = false) - final val SrcAndEdge = new TripletFields(useSrc = true, useDst = false, useEdge = true) - final val Src = SrcAndEdge - final val DstAndEdge = new TripletFields(useSrc = false, useDst = true, useEdge = true) - final val Dst = DstAndEdge - final val All = new TripletFields(useSrc = true, useDst = true, useEdge = true) - - /** Returns the appropriate [[TripletFields]] object. */ - private[graphx] def apply(useSrc: Boolean, useDst: Boolean, useEdge: Boolean) = - (useSrc, useDst, useEdge) match { - case (false, false, false) => TripletFields.None - case (false, false, true) => EdgeOnly - case (true, false, false) => SrcOnly - case (false, true, false) => DstOnly - case (true, true, false) => SrcDstOnly - case (true, false, true) => SrcAndEdge - case (false, true, true) => DstAndEdge - case (true, true, true) => All - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 86ee3923151b4..78d8ac24b5271 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -24,9 +24,17 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.util.collection.BitSet /** - * A collection of edges stored in columnar format, along with any vertex attributes referenced. The - * edges are stored in 3 large columnar arrays (src, dst, attribute). The arrays are clustered by - * src. There is an optional active vertex set for filtering computation on the edges. + * A collection of edges, along with referenced vertex attributes and an optional active vertex set + * for filtering computation on the edges. + * + * The edges are stored in columnar format in `localSrcIds`, `localDstIds`, and `data`. All + * referenced global vertex ids are mapped to a compact set of local vertex ids according to the + * `global2local` map. Each local vertex id is a valid index into `vertexAttrs`, which stores the + * corresponding vertex attribute, and `local2global`, which stores the reverse mapping to global + * vertex id. The global vertex ids that are active are optionally stored in `activeSet`. + * + * The edges are clustered by source vertex id, and the mapping from global vertex id to the index + * of the corresponding edge cluster is stored in `index`. 
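Spelled out as code, the invariants described above are (an illustrative sketch of expressions valid inside `EdgePartition`, for any referenced global id `vid` and valid edge position `pos`; not code from the patch):

{{{
val localId: Int = global2local(vid)   // compact local id, indexes vertexAttrs
assert(local2global(localId) == vid)   // the two maps are inverses
val a: VD = vertexAttrs(localSrcIds(pos))         // per-edge attribute lookup
val s: VertexId = local2global(localSrcIds(pos))  // is an array access, with
                                                  // no hash lookup per edge
}}}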
* * @tparam ED the edge attribute type * @tparam VD the vertex attribute type @@ -46,15 +54,17 @@ import org.apache.spark.util.collection.BitSet private[graphx] class EdgePartition[ @specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag]( - val localSrcIds: Array[Int] = null, - val localDstIds: Array[Int] = null, - val data: Array[ED] = null, - val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null, - val global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null, - val local2global: Array[VertexId] = null, - val vertexAttrs: Array[VD] = null, - val activeSet: Option[VertexSet] = None - ) extends Serializable { + localSrcIds: Array[Int], + localDstIds: Array[Int], + data: Array[ED], + index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + local2global: Array[VertexId], + vertexAttrs: Array[VD], + activeSet: Option[VertexSet]) + extends Serializable { + + private def this() = this(null, null, null, null, null, null, null, null) /** Return a new `EdgePartition` with the specified edge data. */ def withData[ED2: ClassTag](data: Array[ED2]): EdgePartition[ED2, VD] = { @@ -85,16 +95,18 @@ class EdgePartition[ } /** Return a new `EdgePartition` without any locally cached vertex attributes. */ - def clearVertices[VD2: ClassTag](): EdgePartition[ED, VD2] = { + def withoutVertexAttributes[VD2: ClassTag](): EdgePartition[ED, VD2] = { val newVertexAttrs = new Array[VD2](vertexAttrs.length) new EdgePartition( localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs, activeSet) } - private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos)) + @inline private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos)) - private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos)) + @inline private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos)) + + @inline private def attrs(pos: Int): ED = data(pos) /** Look up vid in activeSet, throwing an exception if it is None. */ def isActive(vid: VertexId): Boolean = { @@ -285,7 +297,7 @@ class EdgePartition[ if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) { // ... run `f` on the matching edge builder.add(srcId, dstId, localSrcIds(i), localDstIds(i), - f(srcId, dstId, this.data(i), other.data(j))) + f(srcId, dstId, this.data(i), other.attrs(j))) } } i += 1 @@ -332,27 +344,53 @@ class EdgePartition[ * It is safe to keep references to the objects from this iterator. */ def tripletIterator( - includeSrc: Boolean = true, includeDst: Boolean = true): Iterator[EdgeTriplet[VD, ED]] = { - new EdgeTripletIterator(this, includeSrc, includeDst) + includeSrc: Boolean = true, includeDst: Boolean = true) + : Iterator[EdgeTriplet[VD, ED]] = new Iterator[EdgeTriplet[VD, ED]] { + private[this] var pos = 0 + + override def hasNext: Boolean = pos < EdgePartition.this.size + + override def next() = { + val triplet = new EdgeTriplet[VD, ED] + val localSrcId = localSrcIds(pos) + val localDstId = localDstIds(pos) + triplet.srcId = local2global(localSrcId) + triplet.dstId = local2global(localDstId) + if (includeSrc) { + triplet.srcAttr = vertexAttrs(localSrcId) + } + if (includeDst) { + triplet.dstAttr = vertexAttrs(localDstId) + } + triplet.attr = data(pos) + pos += 1 + triplet + } } /** * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning - * all edges sequentially and filtering them with `idPred`. + * all edges sequentially. 
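The three activeness flags documented below are derived from the caller's `EdgeDirection`. `GraphImpl` (later in this patch) maps directions to flags as follows, summarized here as a helper that does not itself appear in the patch:

{{{
// Returns (srcMustBeActive, dstMustBeActive, maySatisfyEither).
def directionToFlags(dir: Option[EdgeDirection]): (Boolean, Boolean, Boolean) =
  dir match {
    case Some(EdgeDirection.Both)   => (true, true, false)   // both endpoints
    case Some(EdgeDirection.Either) => (true, true, true)    // at least one
    case Some(EdgeDirection.Out)    => (true, false, false)  // source active
    case Some(EdgeDirection.In)     => (false, true, false)  // dest active
    case _                          => (false, false, false) // no active set
  }
}}}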
* * @param sendMsg generates messages to neighboring vertices of an edge * @param mergeMsg the combiner applied to messages destined to the same vertex - * @param sendMsgUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute - * @param sendMsgUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute - * @param idPred a predicate to filter edges based on their source and destination vertex ids + * @param tripletFields which triplet fields `sendMsg` uses + * @param srcMustBeActive if true, edges will only be considered if their source vertex is in the + * active set + * @param dstMustBeActive if true, edges will only be considered if their destination vertex is in + * the active set + * @param maySatisfyEither if true, only one vertex need be in the active set for an edge to be + * considered * * @return iterator aggregated messages keyed by the receiving vertex id */ - def aggregateMessages[A: ClassTag]( + def aggregateMessagesEdgeScan[A: ClassTag]( sendMsg: EdgeContext[VD, ED, A] => Unit, mergeMsg: (A, A) => A, tripletFields: TripletFields, - idPred: (VertexId, VertexId) => Boolean): Iterator[(VertexId, A)] = { + srcMustBeActive: Boolean, + dstMustBeActive: Boolean, + maySatisfyEither: Boolean): Iterator[(VertexId, A)] = { val aggregates = new Array[A](vertexAttrs.length) val bitset = new BitSet(vertexAttrs.length) @@ -363,14 +401,14 @@ class EdgePartition[ val srcId = local2global(localSrcId) val localDstId = localDstIds(i) val dstId = local2global(localDstId) - if (idPred(srcId, dstId)) { - ctx.localSrcId = localSrcId - ctx.localDstId = localDstId - ctx.srcId = srcId - ctx.dstId = dstId - ctx.attr = data(i) - if (tripletFields.useSrc) { ctx.srcAttr = vertexAttrs(localSrcId) } - if (tripletFields.useDst) { ctx.dstAttr = vertexAttrs(localDstId) } + val srcIsActive = !srcMustBeActive || isActive(srcId) + val dstIsActive = !dstMustBeActive || isActive(dstId) + val edgeIsActive = + if (maySatisfyEither) srcIsActive || dstIsActive else srcIsActive && dstIsActive + if (edgeIsActive) { + val srcAttr = if (tripletFields.useSrc) vertexAttrs(localSrcId) else null.asInstanceOf[VD] + val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD] + ctx.set(srcId, dstId, localSrcId, localDstId, srcAttr, dstAttr, data(i)) sendMsg(ctx) } i += 1 @@ -381,22 +419,27 @@ class EdgePartition[ /** * Send messages along edges and aggregate them at the receiving vertices. Implemented by - * filtering the source vertex index with `srcIdPred`, then scanning edge clusters and filtering - * with `dstIdPred`. Both `srcIdPred` and `dstIdPred` must match for an edge to run. + * filtering the source vertex index, then scanning each edge cluster. 
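Which of the two scan strategies runs is decided per partition in `GraphImpl` by the fraction of active vertices; the dispatch (shown in full later in this patch, here for the `EdgeDirection.Both` flags) has the shape:

{{{
// The index scan pays off only when few source vertices are active; 0.8 is
// the cutoff used by GraphImpl.aggregateMessagesWithActiveSet.
if (activeFraction < 0.8) {
  edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
    true, true, false)
} else {
  edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
    true, true, false)
}
}}}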
* * @param sendMsg generates messages to neighboring vertices of an edge * @param mergeMsg the combiner applied to messages destined to the same vertex - * @param srcIdPred a predicate to filter edges based on their source vertex id - * @param dstIdPred a predicate to filter edges based on their destination vertex id + * @param tripletFields which triplet fields `sendMsg` uses + * @param srcMustBeActive if true, edges will only be considered if their source vertex is in the + * active set + * @param dstMustBeActive if true, edges will only be considered if their destination vertex is in + * the active set + * @param maySatisfyEither if true, only one vertex need be in the active set for an edge to be + * considered * * @return iterator aggregated messages keyed by the receiving vertex id */ - def aggregateMessagesWithIndex[A: ClassTag]( + def aggregateMessagesIndexScan[A: ClassTag]( sendMsg: EdgeContext[VD, ED, A] => Unit, mergeMsg: (A, A) => A, tripletFields: TripletFields, - srcIdPred: VertexId => Boolean, - dstIdPred: VertexId => Boolean): Iterator[(VertexId, A)] = { + srcMustBeActive: Boolean, + dstMustBeActive: Boolean, + maySatisfyEither: Boolean): Iterator[(VertexId, A)] = { val aggregates = new Array[A](vertexAttrs.length) val bitset = new BitSet(vertexAttrs.length) @@ -405,19 +448,22 @@ class EdgePartition[ val clusterSrcId = cluster._1 val clusterPos = cluster._2 val clusterLocalSrcId = localSrcIds(clusterPos) - if (srcIdPred(clusterSrcId)) { + val srcIsActive = !srcMustBeActive || isActive(clusterSrcId) + if (srcIsActive || maySatisfyEither) { var pos = clusterPos - ctx.srcId = clusterSrcId - ctx.localSrcId = clusterLocalSrcId - if (tripletFields.useSrc) { ctx.srcAttr = vertexAttrs(clusterLocalSrcId) } + val srcAttr = + if (tripletFields.useSrc) vertexAttrs(clusterLocalSrcId) else null.asInstanceOf[VD] + ctx.setSrcOnly(clusterSrcId, clusterLocalSrcId, srcAttr) while (pos < size && localSrcIds(pos) == clusterLocalSrcId) { val localDstId = localDstIds(pos) val dstId = local2global(localDstId) - if (dstIdPred(dstId)) { - ctx.dstId = dstId - ctx.localDstId = localDstId - ctx.attr = data(pos) - if (tripletFields.useDst) { ctx.dstAttr = vertexAttrs(localDstId) } + val dstIsActive = !dstMustBeActive || isActive(dstId) + val edgeIsActive = + if (maySatisfyEither) srcIsActive || dstIsActive else srcIsActive && dstIsActive + if (edgeIsActive) { + val dstAttr = + if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD] + ctx.setRest(dstId, localDstId, dstAttr, data(pos)) sendMsg(ctx) } pos += 1 @@ -435,23 +481,55 @@ private class AggregatingEdgeContext[VD, ED, A]( bitset: BitSet) extends EdgeContext[VD, ED, A] { - var srcId: VertexId = _ - var dstId: VertexId = _ - var srcAttr: VD = _ - var dstAttr: VD = _ - var attr: ED = _ + private[this] var _srcId: VertexId = _ + private[this] var _dstId: VertexId = _ + private[this] var _localSrcId: Int = _ + private[this] var _localDstId: Int = _ + private[this] var _srcAttr: VD = _ + private[this] var _dstAttr: VD = _ + private[this] var _attr: ED = _ + + def set( + srcId: VertexId, dstId: VertexId, + localSrcId: Int, localDstId: Int, + srcAttr: VD, dstAttr: VD, + attr: ED) { + _srcId = srcId + _dstId = dstId + _localSrcId = localSrcId + _localDstId = localDstId + _srcAttr = srcAttr + _dstAttr = dstAttr + _attr = attr + } + + def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD) { + _srcId = srcId + _localSrcId = localSrcId + _srcAttr = srcAttr + } + + def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) { 
+ _dstId = dstId + _localDstId = localDstId + _dstAttr = dstAttr + _attr = attr + } - var localSrcId: Int = _ - var localDstId: Int = _ + override def srcId = _srcId + override def dstId = _dstId + override def srcAttr = _srcAttr + override def dstAttr = _dstAttr + override def attr = _attr override def sendToSrc(msg: A) { - send(localSrcId, msg) + send(_localSrcId, msg) } override def sendToDst(msg: A) { - send(localDstId, msg) + send(_localDstId, msg) } - private def send(localId: Int, msg: A) { + @inline private def send(localId: Int, msg: A) { if (bitset.get(localId)) { aggregates(localId) = mergeMsg(aggregates(localId), msg) } else { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 95a9dca3d16e7..b0cb0fe47d461 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -29,7 +29,7 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap private[graphx] class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( size: Int = 64) { - var edges = new PrimitiveVector[Edge[ED]](size) + private[this] val edges = new PrimitiveVector[Edge[ED]](size) /** Add a new edge to the partition. */ def add(src: VertexId, dst: VertexId, d: ED) { @@ -71,7 +71,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla vertexAttrs = new Array[VD](currLocalId + 1) } new EdgePartition( - localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs) + localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs, + None) } } @@ -87,7 +88,7 @@ class ExistingEdgePartitionBuilder[ vertexAttrs: Array[VD], activeSet: Option[VertexSet], size: Int = 64) { - var edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size) + private[this] val edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size) /** Add a new edge to the partition. */ def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala deleted file mode 100644 index a8f829ed20a34..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.graphx.impl - -import scala.reflect.ClassTag - -import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap - -/** - * The Iterator type returned when constructing edge triplets. This could be an anonymous class in - * EdgePartition.tripletIterator, but we name it here explicitly so it is easier to debug / profile. - */ -private[impl] -class EdgeTripletIterator[VD: ClassTag, ED: ClassTag]( - val edgePartition: EdgePartition[ED, VD], - val includeSrc: Boolean, - val includeDst: Boolean) - extends Iterator[EdgeTriplet[VD, ED]] { - - // Current position in the array. - private var pos = 0 - - override def hasNext: Boolean = pos < edgePartition.size - - override def next() = { - val triplet = new EdgeTriplet[VD, ED] - val localSrcId = edgePartition.localSrcIds(pos) - val localDstId = edgePartition.localDstIds(pos) - triplet.srcId = edgePartition.local2global(localSrcId) - triplet.dstId = edgePartition.local2global(localDstId) - if (includeSrc) { - triplet.srcAttr = edgePartition.vertexAttrs(localSrcId) - } - if (includeDst) { - triplet.dstAttr = edgePartition.vertexAttrs(localDstId) - } - triplet.attr = edgePartition.data(pos) - pos += 1 - triplet - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index bcbb22b9100dc..a1fe158b7b490 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -186,16 +186,16 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") - val tripletFields = TripletFields(mapUsesSrcAttr, mapUsesDstAttr, useEdge = true) + val tripletFields = new TripletFields(mapUsesSrcAttr, mapUsesDstAttr, true) - aggregateMessages(sendMsg, reduceFunc, tripletFields, activeSetOpt) + aggregateMessagesWithActiveSet(sendMsg, reduceFunc, tripletFields, activeSetOpt) } - override def aggregateMessages[A: ClassTag]( + override def aggregateMessagesWithActiveSet[A: ClassTag]( sendMsg: EdgeContext[VD, ED, A] => Unit, mergeMsg: (A, A) => A, tripletFields: TripletFields, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = { + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = { vertices.cache() // For each vertex, replicate its attribute only to partitions where it is @@ -217,33 +217,31 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( activeDirectionOpt match { case Some(EdgeDirection.Both) => if (activeFraction < 0.8) { - edgePartition.aggregateMessagesWithIndex(sendMsg, mergeMsg, tripletFields, - srcId => edgePartition.isActive(srcId), - dstId => edgePartition.isActive(dstId)) + edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields, + true, true, false) } else { - edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields, - (srcId, dstId) => edgePartition.isActive(srcId) && edgePartition.isActive(dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + true, true, false) } case Some(EdgeDirection.Either) => // TODO: Because we only have a clustered index on the source vertex ID, we can't filter // the index here. Instead we have to scan all edges and then do the filter. 
- edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields, - (srcId, dstId) => edgePartition.isActive(srcId) || edgePartition.isActive(dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + true, true, true) case Some(EdgeDirection.Out) => if (activeFraction < 0.8) { - edgePartition.aggregateMessagesWithIndex(sendMsg, mergeMsg, tripletFields, - srcId => edgePartition.isActive(srcId), - dstId => true) + edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields, + true, false, false) } else { - edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields, - (srcId, dstId) => edgePartition.isActive(srcId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + true, false, false) } case Some(EdgeDirection.In) => - edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields, - (srcId, dstId) => edgePartition.isActive(dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + false, true, false) case _ => // None - edgePartition.aggregateMessages(sendMsg, mergeMsg, tripletFields, - (srcId, dstId) => true) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + false, false, false) } }).setName("GraphImpl.aggregateMessages - preAgg") @@ -327,7 +325,7 @@ object GraphImpl { vertices: VertexRDD[VD], edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = { // Convert the vertex partitions in edges to the correct type - val newEdges = edges.mapEdgePartitions((pid, part) => part.clearVertices[VD]) + val newEdges = edges.mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) GraphImpl.fromExistingRDDs(vertices, newEdges) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index c7a59990ce8e7..515f3a9cd02eb 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -102,6 +102,16 @@ class EdgePartitionSuite extends FunSuite { assert(ep.numActives == Some(2)) } + test("tripletIterator") { + val builder = new EdgePartitionBuilder[Int, Int] + builder.add(1, 2, 0) + builder.add(1, 3, 0) + builder.add(1, 4, 0) + val ep = builder.toEdgePartition + val result = ep.tripletIterator().toList.map(et => (et.srcId, et.dstId)) + assert(result === Seq((1, 2), (1, 3), (1, 4))) + } + test("serialization") { val aList = List((0, 1, 1), (1, 0, 2), (1, 2, 3), (5, 4, 4), (5, 5, 5)) val a: EdgePartition[Int, Int] = makeEdgePartition(aList) @@ -113,7 +123,6 @@ class EdgePartitionSuite extends FunSuite { for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a)) assert(aSer.tripletIterator().toList === a.tripletIterator().toList) - assert(aSer.index != null) } } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala deleted file mode 100644 index 49b2704390fea..0000000000000 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import scala.reflect.ClassTag -import scala.util.Random - -import org.scalatest.FunSuite - -import org.apache.spark.graphx._ - -class EdgeTripletIteratorSuite extends FunSuite { - test("iterator.toList") { - val builder = new EdgePartitionBuilder[Int, Int] - builder.add(1, 2, 0) - builder.add(1, 3, 0) - builder.add(1, 4, 0) - val iter = new EdgeTripletIterator[Int, Int](builder.toEdgePartition, true, true) - val result = iter.toList.map(et => (et.srcId, et.dstId)) - assert(result === Seq((1, 2), (1, 3), (1, 4))) - } -}
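Taken together, the final code path can be exercised end to end much like the "aggregateMessages" test added earlier in this patch series; a condensed sketch, assuming a SparkContext `sc` and the `starGraph` helper from GraphSuite:

{{{
// Star graph with edges 0 -> 1..5 and vertex attribute "v" everywhere; each
// spoke's destination receives the center's attribute.
val star: Graph[String, Int] = starGraph(sc, 5)
val agg: VertexRDD[String] = star.aggregateMessages[String](
  ctx => ctx.sendToDst(ctx.srcAttr),  // reads only srcAttr...
  _ + _,
  TripletFields.Src)                  // ...so dstAttr need not be populated
assert(agg.collect().toSet === (1 to 5).map(x => (x: VertexId, "v")).toSet)
}}}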