diff --git a/ml/src/main/scala/com/github/cloudml/zen/ml/clustering/LDA.scala b/ml/src/main/scala/com/github/cloudml/zen/ml/clustering/LDA.scala
index 5bfeb0fe..a4807945 100644
--- a/ml/src/main/scala/com/github/cloudml/zen/ml/clustering/LDA.scala
+++ b/ml/src/main/scala/com/github/cloudml/zen/ml/clustering/LDA.scala
@@ -32,7 +32,7 @@ import org.apache.spark.storage.StorageLevel
 
 class LDA(@transient var edges: EdgeRDDImpl[TA, _],
-  @transient var verts: VertexRDDImpl[TC],
+  @transient var verts: VertexRDD[TC],
   val numTopics: Int,
   val numTerms: Int,
   val numDocs: Long,
diff --git a/ml/src/main/scala/com/github/cloudml/zen/ml/clustering/algorithm/LDAAlgorithm.scala b/ml/src/main/scala/com/github/cloudml/zen/ml/clustering/algorithm/LDAAlgorithm.scala
index c8add004..2bf4bf20 100644
--- a/ml/src/main/scala/com/github/cloudml/zen/ml/clustering/algorithm/LDAAlgorithm.scala
+++ b/ml/src/main/scala/com/github/cloudml/zen/ml/clustering/algorithm/LDAAlgorithm.scala
@@ -25,6 +25,7 @@ import com.github.cloudml.zen.ml.clustering.LDADefines._
 import com.github.cloudml.zen.ml.clustering.{LDALogLikelihood, LDAPerplexity}
 import com.github.cloudml.zen.ml.util.BVDecompressor
 import com.github.cloudml.zen.ml.util.Concurrent._
+import org.apache.spark.graphx2.VertexRDD
 import org.apache.spark.graphx2.impl.{ShippableVertexPartition => VertPartition, _}
 
 import scala.collection.JavaConversions._
@@ -74,7 +75,7 @@ abstract class LDAAlgorithm(numTopics: Int,
   }
 
   def sampleGraph(edges: EdgeRDDImpl[TA, _],
-    verts: VertexRDDImpl[TC],
+    verts: VertexRDD[TC],
     topicCounters: BDV[Count],
     seed: Int,
     sampIter: Int,
@@ -98,7 +99,7 @@ abstract class LDAAlgorithm(numTopics: Int,
   }
 
   def updateVertexCounters(edges: EdgeRDDImpl[TA, Int],
-    verts: VertexRDDImpl[TC]): VertexRDDImpl[TC] = {
+    verts: VertexRDD[TC]): VertexRDD[TC] = {
     val shippedCounters = edges.partitionsRDD.mapPartitions(_.flatMap {
       case (_, ep) => countPartition(ep)
     }).partitionBy(verts.partitioner.get)
@@ -112,7 +113,7 @@ abstract class LDAAlgorithm(numTopics: Int,
   }
 
   def calcPerplexity(edges: EdgeRDDImpl[TA, _],
-    verts: VertexRDDImpl[TC],
+    verts: VertexRDD[TC],
     topicCounters: BDV[Count],
     numTokens: Long,
     numTerms: Int,
@@ -131,7 +132,7 @@ abstract class LDAAlgorithm(numTopics: Int,
     new LDAPerplexity(pplx, wpplx, dpplx)
   }
 
-  def calcLogLikelihood(verts: VertexRDDImpl[TC],
+  def calcLogLikelihood(verts: VertexRDD[TC],
     topicCounters: BDV[Count],
     numTokens: Long,
     numDocs: Long,
@@ -151,7 +152,7 @@ abstract class LDAAlgorithm(numTopics: Int,
   }
 
   def refreshEdgeAssociations(edges: EdgeRDDImpl[TA, _],
-    verts: VertexRDDImpl[TC]): EdgeRDDImpl[TA, Nvk] = {
+    verts: VertexRDD[TC]): EdgeRDDImpl[TA, Nvk] = {
     val shippedVerts = verts.partitionsRDD.mapPartitions(_.flatMap { vp =>
       val rt = vp.routingTable
       val index = vp.index
@@ -201,7 +202,7 @@ abstract class LDAAlgorithm(numTopics: Int,
     edges.withPartitionsRDD(partRDD)
   }
 
-  def collectTopicCounters(verts: VertexRDDImpl[TC]): BDV[Count] = {
+  def collectTopicCounters(verts: VertexRDD[TC]): BDV[Count] = {
     verts.partitionsRDD.mapPartitions(_.map { vp =>
       val totalSize = vp.capacity
       val index = vp.index
diff --git a/ml/src/main/scala/org/apache/spark/graphx2/impl/EdgePartition.scala b/ml/src/main/scala/org/apache/spark/graphx2/impl/EdgePartition.scala
index 404b6d4e..9a3bcf00 100644
--- a/ml/src/main/scala/org/apache/spark/graphx2/impl/EdgePartition.scala
+++ b/ml/src/main/scala/org/apache/spark/graphx2/impl/EdgePartition.scala
@@ -109,11 +109,11 @@ class EdgePartition[
       activeSet)
   }
 
-  @inline def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos))
+  @inline private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos))
 
-  @inline def dstIds(pos: Int): VertexId = local2global(localDstIds(pos))
+  @inline private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos))
 
-  @inline def attrs(pos: Int): ED = data(pos)
+  @inline private def attrs(pos: Int): ED = data(pos)
 
   /** Look up vid in activeSet, throwing an exception if it is None. */
   def isActive(vid: VertexId): Boolean = {
diff --git a/ml/src/main/scala/org/apache/spark/graphx2/impl/VertexRDDImpl.scala b/ml/src/main/scala/org/apache/spark/graphx2/impl/VertexRDDImpl.scala
index 7894bed0..d78e1587 100644
--- a/ml/src/main/scala/org/apache/spark/graphx2/impl/VertexRDDImpl.scala
+++ b/ml/src/main/scala/org/apache/spark/graphx2/impl/VertexRDDImpl.scala
@@ -233,7 +233,7 @@ class VertexRDDImpl[VD] (
   }
 
   override def withPartitionsRDD[VD2: ClassTag](
-      partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDDImpl[VD2] = {
+      partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] = {
    new VertexRDDImpl(partitionsRDD, this.targetStorageLevel)
  }
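
Note on the signature changes above: widening VertexRDDImpl[TC]/VertexRDDImpl[VD2] to the abstract VertexRDD type keeps the concrete implementation class an internal detail, so callers only program against the public GraphX2 interface. A minimal sketch of code written against the widened type follows; the object and helper names are hypothetical and not part of this patch.

import org.apache.spark.storage.StorageLevel
import org.apache.spark.graphx2.VertexRDD

object VertexRDDSketch {
  // Accepts and returns the abstract VertexRDD type, mirroring the widened
  // signatures above; nothing here depends on VertexRDDImpl directly.
  def persistVerts[VD](verts: VertexRDD[VD]): VertexRDD[VD] = {
    verts.persist(StorageLevel.MEMORY_AND_DISK)
    verts
  }
}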