Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* the underlying index structures can be reused.
*
* @param map the function from an edge object to a new edge value.
* @param mapUsesSrcAttr indicates whether the source vertex attribute should be included in
* the triplet. Setting this to false can improve performance if the source vertex attribute
* is not needed.
* @param mapUsesDstAttr indicates whether the destination vertex attribute should be included in
* the triplet. Setting this to false can improve performance if the destination vertex attribute
* is not needed.
*
* @tparam ED2 the new edge data type
*
Expand All @@ -207,8 +213,10 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* }}}
*
*/
def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
mapTriplets((pid, iter) => iter.map(map))
def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'd deprecate the old one and create a new one for the new api.

the old one should keep the old behavior

tripletFields: TripletFields = TripletFields.All)
: Graph[VD, ED2] = {
mapTriplets((pid, iter) => iter.map(map), tripletFields)
}

/**
Expand All @@ -223,12 +231,18 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* the underlying index structures can be reused.
*
* @param map the iterator transform
* @param mapUsesSrcAttr indicates whether the source vertex attribute should be included in
* the triplet. Setting this to false can improve performance if the source vertex attribute
* is not needed.
* @param mapUsesDstAttr indicates whether the destination vertex attribute should be included in
* the triplet. Setting this to false can improve performance if the destination vertex attribute
* is not needed.
*
* @tparam ED2 the new edge data type
*
*/
def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2])
: Graph[VD, ED2]
def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
tripletFields: TripletFields): Graph[VD, ED2]

/**
* Reverses all edges in the graph. If this graph contains an edge from a to b then the returned
Expand Down Expand Up @@ -258,7 +272,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
*/
def subgraph(
epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
vpred: (VertexId, VD) => Boolean = ((v, d) => true))
vpred: (VertexId, VD) => Boolean = ((v, d) => true),
tripletFields: TripletFields = TripletFields.All)
: Graph[VD, ED]

/**
Expand Down Expand Up @@ -303,6 +318,12 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set
* . If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the
* active set. The active set must have the same index as the graph's vertices.
* @param mapUsesSrcAttr indicates whether the source vertex attribute should be included in
* the triplet. Setting this to false can improve performance if the source vertex attribute
* is not needed.
* @param mapUsesDstAttr indicates whether the destination vertex attribute should be included in
* the triplet. Setting this to false can improve performance if the destination vertex attribute
* is not needed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These comments are out of date now that we have tripletFields

*
* @example We can use this function to compute the in-degree of each
* vertex
Expand All @@ -322,12 +343,13 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
def mapReduceTriplets[A: ClassTag](
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
reduceFunc: (A, A) => A,
tripletFields: TripletFields = TripletFields.All,
activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
: VertexRDD[A]

/**
* Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The
* input table should contain at most one entry for each vertex. If no entry in `other` is
* Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`.
* The input table should contain at most one entry for each vertex. If no entry in `other` is
* provided for a particular vertex in the graph, the map function receives `None`.
*
* @tparam U the type of entry in the table of updates
Expand Down
61 changes: 33 additions & 28 deletions graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
*/
private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = {
if (edgeDirection == EdgeDirection.In) {
graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _)
graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _, TripletFields.None)
} else if (edgeDirection == EdgeDirection.Out) {
graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _)
graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _, TripletFields.None)
} else { // EdgeDirection.Either
graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _)
graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _,
TripletFields.None)
}
}

Expand All @@ -90,16 +91,15 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
if (edgeDirection == EdgeDirection.Either) {
graph.mapReduceTriplets[Array[VertexId]](
mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))),
reduceFunc = _ ++ _
)
reduceFunc = _ ++ _, tripletFields = TripletFields.None)
} else if (edgeDirection == EdgeDirection.Out) {
graph.mapReduceTriplets[Array[VertexId]](
mapFunc = et => Iterator((et.srcId, Array(et.dstId))),
reduceFunc = _ ++ _)
reduceFunc = _ ++ _, tripletFields = TripletFields.None)
} else if (edgeDirection == EdgeDirection.In) {
graph.mapReduceTriplets[Array[VertexId]](
mapFunc = et => Iterator((et.dstId, Array(et.srcId))),
reduceFunc = _ ++ _)
reduceFunc = _ ++ _, tripletFields = TripletFields.None)
} else {
throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
"direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
Expand All @@ -122,22 +122,25 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
* @return the vertex set of neighboring vertex attributes for each vertex
*/
def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = {
val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]](
edge => {
val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr)))
val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr)))
edgeDirection match {
case EdgeDirection.Either => Iterator(msgToSrc, msgToDst)
case EdgeDirection.In => Iterator(msgToDst)
case EdgeDirection.Out => Iterator(msgToSrc)
case EdgeDirection.Both =>
throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" +
"EdgeDirection.Either instead.")
}
},
(a, b) => a ++ b)

graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
val nbrs = edgeDirection match {
case EdgeDirection.Either =>
graph.mapReduceTriplets[Array[(VertexId,VD)]](
edge => Iterator((edge.srcId, Array((edge.dstId, edge.dstAttr))),
(edge.dstId, Array((edge.srcId, edge.srcAttr)))),
(a, b) => a ++ b, TripletFields.SrcDstOnly)
case EdgeDirection.In =>
graph.mapReduceTriplets[Array[(VertexId,VD)]](
edge => Iterator((edge.dstId, Array((edge.srcId, edge.srcAttr)))),
(a, b) => a ++ b, TripletFields.SrcOnly)
case EdgeDirection.Out =>
graph.mapReduceTriplets[Array[(VertexId,VD)]](
edge => Iterator((edge.srcId, Array((edge.dstId, edge.dstAttr)))),
(a, b) => a ++ b, TripletFields.DstOnly)
case EdgeDirection.Both =>
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
"EdgeDirection.Either instead.")
}
graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) =>
nbrsOpt.getOrElse(Array.empty[(VertexId, VD)])
}
} // end of collectNeighbor
Expand All @@ -163,15 +166,15 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
graph.mapReduceTriplets[Array[Edge[ED]]](
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
(edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
(a, b) => a ++ b)
(a, b) => a ++ b, TripletFields.EdgeOnly)
case EdgeDirection.In =>
graph.mapReduceTriplets[Array[Edge[ED]]](
edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
(a, b) => a ++ b)
(a, b) => a ++ b, TripletFields.EdgeOnly)
case EdgeDirection.Out =>
graph.mapReduceTriplets[Array[Edge[ED]]](
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
(a, b) => a ++ b)
(a, b) => a ++ b, TripletFields.EdgeOnly)
case EdgeDirection.Both =>
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
"EdgeDirection.Either instead.")
Expand Down Expand Up @@ -324,9 +327,11 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
activeDirection: EdgeDirection = EdgeDirection.Either)(
vprog: (VertexId, VD, A) => VD,
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)],
mergeMsg: (A, A) => A)
mergeMsg: (A, A) => A,
tripletFields: TripletFields = TripletFields.All)
: Graph[VD, ED] = {
Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg)
Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg,
tripletFields)
}

/**
Expand Down
6 changes: 4 additions & 2 deletions graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ object Pregel extends Logging {
activeDirection: EdgeDirection = EdgeDirection.Either)
(vprog: (VertexId, VD, A) => VD,
sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
mergeMsg: (A, A) => A)
mergeMsg: (A, A) => A,
tripletFields: TripletFields = TripletFields.All)
: Graph[VD, ED] =
{
var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
Expand All @@ -138,7 +139,8 @@ object Pregel extends Logging {
// Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
// get to send messages. We must cache messages so it can be materialized on the next line,
// allowing us to uncache the previous iteration.
messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
messages = g.mapReduceTriplets(sendMsg, mergeMsg, tripletFields,
Some((newVerts, activeDirection))).cache()
// The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
// hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
// vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
Expand Down
47 changes: 47 additions & 0 deletions graphx/src/main/scala/org/apache/spark/graphx/TripletFields.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.graphx


class TripletFields private(
val useSrc: Boolean,
val useDst: Boolean,
val useEdge: Boolean)
extends Serializable {
/**
* Default triplet fields includes all fields
*/
def this() = this(true, true, true)
}


/**
* A set of [[TripletFields]]s.
*/
object TripletFields {
final val None = new TripletFields(useSrc = false, useDst = false, useEdge = false)
final val EdgeOnly = new TripletFields(useSrc = false, useDst = false, useEdge = true)
final val SrcOnly = new TripletFields(useSrc = true, useDst = false, useEdge = false)
final val DstOnly = new TripletFields(useSrc = false, useDst = true, useEdge = false)
final val SrcDstOnly = new TripletFields(useSrc = true, useDst = true, useEdge = false)
final val SrcAndEdge = new TripletFields(useSrc = true, useDst = false, useEdge = true)
final val Src = SrcAndEdge
final val DstAndEdge = new TripletFields(useSrc = false, useDst = true, useEdge = true)
final val Dst = DstAndEdge
final val All = new TripletFields(useSrc = true, useDst = true, useEdge = true)
}
34 changes: 12 additions & 22 deletions graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,26 +127,26 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
}

override def mapTriplets[ED2: ClassTag](
f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
tripletFields: TripletFields): Graph[VD, ED2] = {
vertices.cache()
val mapUsesSrcAttr = accessesVertexAttr(f, "srcAttr")
val mapUsesDstAttr = accessesVertexAttr(f, "dstAttr")
replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr)
replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
val newEdges = replicatedVertexView.edges.mapEdgePartitions { (pid, part) =>
part.map(f(pid, part.tripletIterator(mapUsesSrcAttr, mapUsesDstAttr)))
part.map(f(pid, part.tripletIterator(tripletFields.useSrc, tripletFields.useDst)))
}
new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges))
}

override def subgraph(
epred: EdgeTriplet[VD, ED] => Boolean = x => true,
vpred: (VertexId, VD) => Boolean = (a, b) => true): Graph[VD, ED] = {
vpred: (VertexId, VD) => Boolean = (a, b) => true,
tripletFields: TripletFields = TripletFields.All): Graph[VD, ED] = {
vertices.cache()
// Filter the vertices, reusing the partitioner and the index from this graph
val newVerts = vertices.mapVertexPartitions(_.filter(vpred))
// Filter the triplets. We must always upgrade the triplet view fully because vpred always runs
// on both src and dst vertices
replicatedVertexView.upgrade(vertices, true, true)
replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this change is correct - see the comment on the line above.

val newEdges = replicatedVertexView.edges.filter(epred, vpred)
new GraphImpl(newVerts, replicatedVertexView.withEdges(newEdges))
}
Expand All @@ -171,15 +171,13 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
override def mapReduceTriplets[A: ClassTag](
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
reduceFunc: (A, A) => A,
activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = {
tripletFields: TripletFields,
activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = {

vertices.cache()

// For each vertex, replicate its attribute only to partitions where it is
// in the relevant position in an edge.
val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr")
val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr")
replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr)
replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
val view = activeSetOpt match {
case Some((activeSet, _)) =>
replicatedVertexView.withActiveSet(activeSet)
Expand Down Expand Up @@ -220,8 +218,8 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
}

// Scan edges and run the map function
val mapOutputs = edgePartition.upgradeIterator(edgeIter, mapUsesSrcAttr, mapUsesDstAttr)
.flatMap(mapFunc(_))
val mapOutputs = edgePartition.upgradeIterator(edgeIter, tripletFields.useSrc,
tripletFields.useDst).flatMap(mapFunc(_))
// Note: This doesn't allow users to send messages to arbitrary vertices.
edgePartition.vertices.aggregateUsingIndex(mapOutputs, reduceFunc).iterator
}).setName("GraphImpl.mapReduceTriplets - preAgg")
Expand Down Expand Up @@ -251,14 +249,6 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
}
}

/** Test whether the closure accesses the the attribute with name `attrName`. */
private def accessesVertexAttr(closure: AnyRef, attrName: String): Boolean = {
try {
BytecodeUtils.invokedMethod(closure, classOf[EdgeTriplet[VD, ED]], attrName)
} catch {
case _: ClassNotFoundException => true // if we don't know, be conservative
}
}
} // end of class GraphImpl


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ object ConnectedComponents {
Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)(
vprog = (id, attr, msg) => math.min(attr, msg),
sendMsg = sendMessage,
mergeMsg = (a, b) => math.min(a, b))
mergeMsg = (a, b) => math.min(a, b), TripletFields.All)
} // end of connectedComponents
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ object PageRank extends Logging {
// Associate the degree with each vertex
.outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) }
// Set the weight on the edges based on the degree
.mapTriplets( e => 1.0 / e.srcAttr )
.mapTriplets( e => 1.0 / e.srcAttr, TripletFields.SrcOnly )
// Set the vertex attributes to the initial pagerank values
.mapVertices( (id, attr) => resetProb )

Expand All @@ -97,7 +97,7 @@ object PageRank extends Logging {
// Compute the outgoing rank contributions of each vertex, perform local preaggregation, and
// do the final aggregation at the receiving vertices. Requires a shuffle for aggregation.
val rankUpdates = rankGraph.mapReduceTriplets[Double](
e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _)
e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _, TripletFields.SrcAndEdge)

// Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices
// that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the
Expand Down Expand Up @@ -171,7 +171,7 @@ object PageRank extends Logging {

// Execute a dynamic version of Pregel.
Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)(
vertexProgram, sendMessage, messageCombiner)
vertexProgram, sendMessage, messageCombiner, TripletFields.SrcAndEdge)
.mapVertices((vid, attr) => attr._1)
} // end of deltaPageRank
}
Loading