diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index fa4b891754c40..dc11d9d54e32e 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -195,6 +195,12 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * the underlying index structures can be reused.
    *
    * @param map the function from an edge object to a new edge value.
+   * @param tripletFields which fields of the [[EdgeTriplet]] should be populated before the map
+   *   function is invoked. If the map function does not read the source (or destination) vertex
+   *   attribute, excluding that attribute here avoids shipping it to the edge partitions, which
+   *   can improve performance. Defaults to [[TripletFields.All]], which populates every field.
+   *   Note that an excluded field is left at its default (unset) value in the triplets the map
+   *   function receives, so every field the function reads must be requested.
    *
    * @tparam ED2 the new edge data type
    *
@@ -207,8 +213,10 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * }}}
    *
    */
-  def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
-    mapTriplets((pid, iter) => iter.map(map))
+  def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2,
+      tripletFields: TripletFields = TripletFields.All)
+    : Graph[VD, ED2] = {
+    mapTriplets((pid, iter) => iter.map(map), tripletFields)
   }
 
   /**
@@ -223,12 +231,18 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * the underlying index structures can be reused.
    *
    * @param map the iterator transform
+   * @param tripletFields which fields of the [[EdgeTriplet]]s should be populated before the
+   *   iterator transform is invoked. If the transform does not read the source (or destination)
+   *   vertex attribute, excluding that attribute here avoids shipping it to the edge partitions,
+   *   which can improve performance. Note that an excluded field is left at its default (unset)
+   *   value in the triplets the transform receives, so every field the transform reads must be
+   *   requested.
    *
    * @tparam ED2 the new edge data type
    *
    */
-  def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2])
-    : Graph[VD, ED2]
+  def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
+      tripletFields: TripletFields): Graph[VD, ED2]
 
   /**
    * Reverses all edges in the graph. If this graph contains an edge from a to b then the returned
@@ -258,7 +272,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    */
   def subgraph(
       epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
-      vpred: (VertexId, VD) => Boolean = ((v, d) => true))
+      vpred: (VertexId, VD) => Boolean = ((v, d) => true),
+      tripletFields: TripletFields = TripletFields.All)
     : Graph[VD, ED]
 
   /**
@@ -303,6 +318,12 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set
    * . If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the
    * active set. The active set must have the same index as the graph's vertices.
+   * @param tripletFields which fields of the [[EdgeTriplet]] should be populated before `mapFunc`
+   *   is invoked. If `mapFunc` does not read the source (or destination) vertex attribute,
+   *   excluding that attribute here avoids shipping it to the edge partitions, which can improve
+   *   performance. Defaults to [[TripletFields.All]], which populates every field. Note that an
+   *   excluded field is left at its default (unset) value in the triplets passed to `mapFunc`,
+   *   so every field `mapFunc` reads must be requested.
    *
    * @example We can use this function to compute the in-degree of each
    * vertex
@@ -322,12 +343,13 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
   def mapReduceTriplets[A: ClassTag](
       mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
       reduceFunc: (A, A) => A,
+      tripletFields: TripletFields = TripletFields.All,
       activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
     : VertexRDD[A]
 
   /**
-   * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The
-   * input table should contain at most one entry for each vertex. If no entry in `other` is
+   * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`.
+   * The input table should contain at most one entry for each vertex. If no entry in `other` is
    * provided for a particular vertex in the graph, the map function receives `None`.
    *
    * @tparam U the type of entry in the table of updates
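Reviewer note: a minimal sketch of the new `mapTriplets` call shape, for reference. The `graph: Graph[Int, Int]` value here is hypothetical and not part of this patch; only the signature comes from the hunk above.

```scala
import org.apache.spark.graphx._

// Weight each edge by its source vertex attribute. The map function never
// reads dstAttr, so TripletFields.SrcAndEdge avoids shipping destination
// vertex attributes to the edge partitions.
val weighted: Graph[Int, Double] =
  graph.mapTriplets(t => t.attr * t.srcAttr / 100.0, TripletFields.SrcAndEdge)

// Existing call sites keep compiling unchanged: tripletFields defaults to
// TripletFields.All, which preserves the old behavior.
val doubled: Graph[Int, Int] = graph.mapTriplets(t => t.attr * 2)
```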
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index d0dd45dba618e..84f9bd03108cc 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
    */
   private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = {
     if (edgeDirection == EdgeDirection.In) {
-      graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _)
+      graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _, TripletFields.None)
     } else if (edgeDirection == EdgeDirection.Out) {
-      graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _)
+      graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _, TripletFields.None)
     } else { // EdgeDirection.Either
-      graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _)
+      graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _,
+        TripletFields.None)
     }
   }
 
@@ -90,16 +91,15 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
     if (edgeDirection == EdgeDirection.Either) {
       graph.mapReduceTriplets[Array[VertexId]](
         mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))),
-        reduceFunc = _ ++ _
-      )
+        reduceFunc = _ ++ _, tripletFields = TripletFields.None)
     } else if (edgeDirection == EdgeDirection.Out) {
       graph.mapReduceTriplets[Array[VertexId]](
         mapFunc = et => Iterator((et.srcId, Array(et.dstId))),
-        reduceFunc = _ ++ _)
+        reduceFunc = _ ++ _, tripletFields = TripletFields.None)
     } else if (edgeDirection == EdgeDirection.In) {
       graph.mapReduceTriplets[Array[VertexId]](
         mapFunc = et => Iterator((et.dstId, Array(et.srcId))),
-        reduceFunc = _ ++ _)
+        reduceFunc = _ ++ _, tripletFields = TripletFields.None)
     } else {
       throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
         "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
@@ -122,22 +122,25 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
    * @return the vertex set of neighboring vertex attributes for each vertex
    */
   def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = {
-    val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]](
-      edge => {
-        val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr)))
-        val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr)))
-        edgeDirection match {
-          case EdgeDirection.Either => Iterator(msgToSrc, msgToDst)
-          case EdgeDirection.In => Iterator(msgToDst)
-          case EdgeDirection.Out => Iterator(msgToSrc)
-          case EdgeDirection.Both =>
-            throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" +
-              "EdgeDirection.Either instead.")
-        }
-      },
-      (a, b) => a ++ b)
-
-    graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
+    val nbrs = edgeDirection match {
+      case EdgeDirection.Either =>
+        graph.mapReduceTriplets[Array[(VertexId,VD)]](
+          edge => Iterator((edge.srcId, Array((edge.dstId, edge.dstAttr))),
+            (edge.dstId, Array((edge.srcId, edge.srcAttr)))),
+          (a, b) => a ++ b, TripletFields.SrcDstOnly)
+      case EdgeDirection.In =>
+        graph.mapReduceTriplets[Array[(VertexId,VD)]](
+          edge => Iterator((edge.dstId, Array((edge.srcId, edge.srcAttr)))),
+          (a, b) => a ++ b, TripletFields.SrcOnly)
+      case EdgeDirection.Out =>
+        graph.mapReduceTriplets[Array[(VertexId,VD)]](
+          edge => Iterator((edge.srcId, Array((edge.dstId, edge.dstAttr)))),
+          (a, b) => a ++ b, TripletFields.DstOnly)
+      case EdgeDirection.Both =>
+        throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use " +
+          "EdgeDirection.Either instead.")
+    }
+    graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) =>
       nbrsOpt.getOrElse(Array.empty[(VertexId, VD)])
     }
   } // end of collectNeighbor
@@ -163,15 +166,15 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
         graph.mapReduceTriplets[Array[Edge[ED]]](
           edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
             (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
-          (a, b) => a ++ b)
+          (a, b) => a ++ b, TripletFields.EdgeOnly)
       case EdgeDirection.In =>
         graph.mapReduceTriplets[Array[Edge[ED]]](
           edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
-          (a, b) => a ++ b)
+          (a, b) => a ++ b, TripletFields.EdgeOnly)
       case EdgeDirection.Out =>
         graph.mapReduceTriplets[Array[Edge[ED]]](
           edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
-          (a, b) => a ++ b)
+          (a, b) => a ++ b, TripletFields.EdgeOnly)
       case EdgeDirection.Both =>
-        throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
-          "EdgeDirection.Either instead.")
+        throw new SparkException("collectEdges does not support EdgeDirection.Both. Use " +
+          "EdgeDirection.Either instead.")
@@ -324,9 +327,11 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
       activeDirection: EdgeDirection = EdgeDirection.Either)(
       vprog: (VertexId, VD, A) => VD,
       sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)],
-      mergeMsg: (A, A) => A)
+      mergeMsg: (A, A) => A,
+      tripletFields: TripletFields = TripletFields.All)
     : Graph[VD, ED] = {
-    Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg)
+    Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg,
+      tripletFields)
   }
 
   /**
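Reviewer note: the degree and neighbor-id operators above now declare that they read nothing beyond the vertex ids. A small self-contained sketch (the toy edge list and `sc` are illustrative, not from the patch):

```scala
import org.apache.spark.graphx._

// `sc` is an existing SparkContext in this sketch.
val edges = sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1), Edge(3L, 1L, 1)))
val graph = Graph.fromEdges(edges, defaultValue = 0)

// Both calls now pass TripletFields.None internally: their map functions read
// only et.srcId / et.dstId, so no vertex attributes are replicated to edges.
val inDeg: VertexRDD[Int] = graph.inDegrees
val nbrIds: VertexRDD[Array[VertexId]] = graph.collectNeighborIds(EdgeDirection.Either)
```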
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 5e55620147df8..2818b5a31711e 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -116,7 +116,8 @@ object Pregel extends Logging {
       activeDirection: EdgeDirection = EdgeDirection.Either)
      (vprog: (VertexId, VD, A) => VD,
       sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
-      mergeMsg: (A, A) => A)
+      mergeMsg: (A, A) => A,
+      tripletFields: TripletFields = TripletFields.All)
     : Graph[VD, ED] =
   {
     var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
@@ -138,7 +139,8 @@ object Pregel extends Logging {
       // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
       // get to send messages. We must cache messages so it can be materialized on the next line,
       // allowing us to uncache the previous iteration.
-      messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
+      messages = g.mapReduceTriplets(sendMsg, mergeMsg, tripletFields,
+        Some((newVerts, activeDirection))).cache()
       // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
       // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
       // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
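Reviewer note: with the parameter threaded through `Pregel.apply`, algorithms can declare their field usage up front. A hedged single-source shortest-paths sketch (the `sssp` helper is illustrative, not part of this patch); the send function deliberately avoids reading `dstAttr`, since `SrcAndEdge` leaves it unset:

```scala
import org.apache.spark.graphx._

def sssp(graph: Graph[Double, Double], source: VertexId): Graph[Double, Double] = {
  val init = graph.mapVertices((id, _) => if (id == source) 0.0 else Double.PositiveInfinity)
  Pregel(init, initialMsg = Double.PositiveInfinity)(
    vprog = (id, dist, msg) => math.min(dist, msg),
    // Reads only srcAttr and the edge weight; the usual `< t.dstAttr` pruning
    // check is omitted on purpose because dstAttr is not populated.
    sendMsg = t =>
      if (t.srcAttr + t.attr < Double.PositiveInfinity) {
        Iterator((t.dstId, t.srcAttr + t.attr))
      } else {
        Iterator.empty
      },
    mergeMsg = (a, b) => math.min(a, b),
    tripletFields = TripletFields.SrcAndEdge)
}
```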
Use" + "EdgeDirection.Either instead.") @@ -324,9 +327,11 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali activeDirection: EdgeDirection = EdgeDirection.Either)( vprog: (VertexId, VD, A) => VD, sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], - mergeMsg: (A, A) => A) + mergeMsg: (A, A) => A, + tripletFields: TripletFields = TripletFields.All) : Graph[VD, ED] = { - Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg) + Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg, + tripletFields) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 5e55620147df8..2818b5a31711e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -116,7 +116,8 @@ object Pregel extends Logging { activeDirection: EdgeDirection = EdgeDirection.Either) (vprog: (VertexId, VD, A) => VD, sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], - mergeMsg: (A, A) => A) + mergeMsg: (A, A) => A, + tripletFields: TripletFields = TripletFields.All) : Graph[VD, ED] = { var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() @@ -138,7 +139,8 @@ object Pregel extends Logging { // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't // get to send messages. We must cache messages so it can be materialized on the next line, // allowing us to uncache the previous iteration. - messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache() + messages = g.mapReduceTriplets(sendMsg, mergeMsg, tripletFields, + Some((newVerts, activeDirection))).cache() // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala new file mode 100644 index 0000000000000..420109f24cfe4 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + + +class TripletFields private( + val useSrc: Boolean, + val useDst: Boolean, + val useEdge: Boolean) + extends Serializable { + /** + * Default triplet fields includes all fields + */ + def this() = this(true, true, true) +} + + +/** + * A set of [[TripletFields]]s. 
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 33f35cfb69a26..e5965ac5c6fde 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -127,26 +127,26 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
   }
 
   override def mapTriplets[ED2: ClassTag](
-      f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
+      f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
+      tripletFields: TripletFields): Graph[VD, ED2] = {
     vertices.cache()
-    val mapUsesSrcAttr = accessesVertexAttr(f, "srcAttr")
-    val mapUsesDstAttr = accessesVertexAttr(f, "dstAttr")
-    replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr)
+    replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
     val newEdges = replicatedVertexView.edges.mapEdgePartitions { (pid, part) =>
-      part.map(f(pid, part.tripletIterator(mapUsesSrcAttr, mapUsesDstAttr)))
+      part.map(f(pid, part.tripletIterator(tripletFields.useSrc, tripletFields.useDst)))
     }
     new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges))
   }
 
   override def subgraph(
       epred: EdgeTriplet[VD, ED] => Boolean = x => true,
-      vpred: (VertexId, VD) => Boolean = (a, b) => true): Graph[VD, ED] = {
+      vpred: (VertexId, VD) => Boolean = (a, b) => true,
+      tripletFields: TripletFields = TripletFields.All): Graph[VD, ED] = {
     vertices.cache()
     // Filter the vertices, reusing the partitioner and the index from this graph
     val newVerts = vertices.mapVertexPartitions(_.filter(vpred))
-    // Filter the triplets. We must always upgrade the triplet view fully because vpred always runs
-    // on both src and dst vertices
-    replicatedVertexView.upgrade(vertices, true, true)
+    // Filter the triplets, upgrading the view with only the requested fields. Note that vpred
+    // also runs on edge endpoints, so exclude a vertex attribute only if vpred ignores it too.
+    replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
     val newEdges = replicatedVertexView.edges.filter(epred, vpred)
     new GraphImpl(newVerts, replicatedVertexView.withEdges(newEdges))
   }
@@ -171,15 +171,13 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
   override def mapReduceTriplets[A: ClassTag](
       mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
       reduceFunc: (A, A) => A,
-      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = {
+      tripletFields: TripletFields,
+      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = {
 
     vertices.cache()
-
     // For each vertex, replicate its attribute only to partitions where it is
     // in the relevant position in an edge.
-    val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr")
-    val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr")
-    replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr)
+    replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
     val view = activeSetOpt match {
       case Some((activeSet, _)) =>
         replicatedVertexView.withActiveSet(activeSet)
@@ -220,8 +218,8 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
       }
 
       // Scan edges and run the map function
-      val mapOutputs = edgePartition.upgradeIterator(edgeIter, mapUsesSrcAttr, mapUsesDstAttr)
-        .flatMap(mapFunc(_))
+      val mapOutputs = edgePartition.upgradeIterator(edgeIter, tripletFields.useSrc,
+        tripletFields.useDst).flatMap(mapFunc(_))
       // Note: This doesn't allow users to send messages to arbitrary vertices.
       edgePartition.vertices.aggregateUsingIndex(mapOutputs, reduceFunc).iterator
     }).setName("GraphImpl.mapReduceTriplets - preAgg")
@@ -251,14 +249,6 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
     }
   }
 
-  /** Test whether the closure accesses the the attribute with name `attrName`. */
-  private def accessesVertexAttr(closure: AnyRef, attrName: String): Boolean = {
-    try {
-      BytecodeUtils.invokedMethod(closure, classOf[EdgeTriplet[VD, ED]], attrName)
-    } catch {
-      case _: ClassNotFoundException => true // if we don't know, be conservative
-    }
-  }
 
 } // end of class GraphImpl
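Reviewer note: the deleted `accessesVertexAttr` helper inferred field usage by bytecode inspection of the user closure, falling back to a conservative `true` when inspection failed. Explicit `TripletFields` makes the cost model predictable, but the declaration is unchecked; a hypothetical misuse to illustrate the trade-off (`graph: Graph[Int, Int]` is assumed):

```scala
import org.apache.spark.graphx._

// The map function reads dstAttr, but the caller requested only the source
// side. Bytecode inspection would have detected the dstAttr call and shipped
// the attribute anyway; with explicit TripletFields, dstAttr stays at its
// default (unset) value here, silently producing wrong sums.
val wrong: VertexRDD[Int] = graph.mapReduceTriplets[Int](
  t => Iterator((t.srcId, t.dstAttr)), // reads dstAttr...
  _ + _,
  TripletFields.SrcOnly)               // ...but requests only src
```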
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
index e2f6cc138958e..3a922de3ae45b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
@@ -50,6 +50,6 @@ object ConnectedComponents {
     Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)(
       vprog = (id, attr, msg) => math.min(attr, msg),
       sendMsg = sendMessage,
-      mergeMsg = (a, b) => math.min(a, b))
+      mergeMsg = (a, b) => math.min(a, b), tripletFields = TripletFields.All)
   } // end of connectedComponents
 }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index 257e2f3a36115..91037e6552cb4 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -85,7 +85,7 @@ object PageRank extends Logging {
       // Associate the degree with each vertex
       .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) }
       // Set the weight on the edges based on the degree
-      .mapTriplets( e => 1.0 / e.srcAttr )
+      .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.SrcOnly )
       // Set the vertex attributes to the initial pagerank values
       .mapVertices( (id, attr) => resetProb )
 
@@ -97,7 +97,7 @@ object PageRank extends Logging {
       // Compute the outgoing rank contributions of each vertex, perform local preaggregation, and
       // do the final aggregation at the receiving vertices. Requires a shuffle for aggregation.
       val rankUpdates = rankGraph.mapReduceTriplets[Double](
-        e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _)
+        e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _, TripletFields.SrcAndEdge)
 
       // Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices
       // that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the
@@ -171,7 +171,7 @@ object PageRank extends Logging {
 
     // Execute a dynamic version of Pregel.
     Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)(
-      vertexProgram, sendMessage, messageCombiner)
+      vertexProgram, sendMessage, messageCombiner, TripletFields.SrcAndEdge)
       .mapVertices((vid, attr) => attr._1)
   } // end of deltaPageRank
 }
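Reviewer note: the two libraries make different choices above. ConnectedComponents keeps `TripletFields.All` because its send function compares both endpoint labels (`SrcDstOnly` would arguably suffice, since the edge attribute is unused), while PageRank's messages depend only on the source rank and edge weight. A hedged sketch of a CC-style send function, to make the requirement concrete:

```scala
import org.apache.spark.graphx._

// Connected-components style message: both endpoint labels are read, so the
// chosen TripletFields must have useSrc = useDst = true. The edge attribute
// is never touched, so TripletFields.SrcDstOnly would also be safe.
def sendMessage(edge: EdgeTriplet[VertexId, Int]): Iterator[(VertexId, VertexId)] = {
  if (edge.srcAttr < edge.dstAttr) {
    Iterator((edge.dstId, edge.srcAttr))
  } else if (edge.srcAttr > edge.dstAttr) {
    Iterator((edge.srcId, edge.dstAttr))
  } else {
    Iterator.empty
  }
}
```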
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 6506bac73d71c..0473880f84906 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -298,7 +298,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
           throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId))
         }
         Iterator((et.srcId, 1))
-      }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect.toSet
+      }, (a: Int, b: Int) => a + b, TripletFields.All, Some((active, EdgeDirection.In))).collect.toSet
       assert(numEvenNeighbors === (1 to n).map(x => (x: VertexId, n / 2)).toSet)
 
       // outerJoinVertices followed by mapReduceTriplets(activeSetOpt)
@@ -312,7 +312,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
           throw new Exception("map ran on edge with src vid %d, which is even".format(et.dstId))
         }
         Iterator((et.dstId, 1))
-      }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect.toSet
+      }, (a: Int, b: Int) => a + b, TripletFields.All, Some(changed, EdgeDirection.Out)).collect.toSet
       assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexId, 1)).toSet)
     }
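Reviewer note: because `tripletFields` is inserted before `activeSetOpt`, positional callers that pass an active set (as these tests do) need updating. Named arguments avoid restating the default; a hypothetical, simplified equivalent of the first updated call, reusing the test's `graph` and `active` values:

```scala
// tripletFields keeps its TripletFields.All default when activeSetOpt is named.
val numEvenNeighbors = graph.mapReduceTriplets[Int](
  et => Iterator((et.srcId, 1)),
  (a: Int, b: Int) => a + b,
  activeSetOpt = Some((active, EdgeDirection.In))).collect().toSet
```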