Commit ba3aa8f

viirya authored and mengxr committed
[SPARK-5714][Mllib] Refactor initial step of LDA to remove redundant operations

The `initialState` of LDA performs several RDD operations that look redundant. This PR simplifies those operations.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #4501 from viirya/sim_lda and squashes the following commits:

4870fe4 [Liang-Chi Hsieh] For comments.
9af1487 [Liang-Chi Hsieh] Refactor initial step of LDA to remove redundant operations.

(cherry picked from commit f86a89a)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent 63af90c commit ba3aa8f
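
The redundancy the message refers to is a two-pass construction: the old `initialState` first materialized a per-edge gamma RDD (`edgesWithGamma`) and then aggregated it per vertex, once for document vertices and once for term vertices. The sketch below shows the one-pass flatMap-plus-keyed-sum shape the commit switches to, written against plain Scala 2.13 collections rather than Spark RDDs; all names and values here are illustrative, not taken from LDA.scala.

```scala
// Minimal sketch of the one-pass aggregation pattern, using plain Scala
// collections in place of RDDs. Data and object name are illustrative.
object OnePassAggregationSketch {
  def main(args: Array[String]): Unit = {
    // (srcId, dstId, tokenCount) triples standing in for Edge[TokenCount].
    val edges = Seq((1L, -1L, 2.0), (1L, -2L, 3.0), (2L, -1L, 1.0))

    // Each edge emits its contribution to both endpoints in one flatMap,
    // instead of materializing per-edge state and re-mapping it twice.
    val contributions = edges.flatMap { case (src, dst, count) =>
      Seq(src -> count, dst -> count)
    }

    // Per-vertex totals are then a single keyed sum, the collections
    // analogue of RDD.reduceByKey(_ + _).
    val vertexTotals = contributions.groupMapReduce(_._1)(_._2)(_ + _)
    vertexTotals.toSeq.sortBy(_._1).foreach(println)
    // Prints (-2,3.0), (-1,3.0), (1,5.0), (2,1.0).
  }
}
```

On RDDs the same shape is a `flatMap` followed by `reduceByKey(_ + _)`, as in the diff below; because the old `edgesWithGamma` was never cached, evaluating it separately for `docVertices` and `termVertices` re-ran the random-gamma pipeline twice before the two results were unioned.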

File tree

1 file changed (+13, -24 lines)

  • mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 13 additions & 24 deletions
@@ -450,34 +450,23 @@ private[clustering] object LDA {
 
     // Create vertices.
     // Initially, we use random soft assignments of tokens to topics (random gamma).
-    val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] =
-      edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
-        val random = new Random(partIndex + randomSeed)
-        partEdges.map { edge =>
-          // Create a random gamma_{wjk}
-          (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0))
+    def createVertices(): RDD[(VertexId, TopicCounts)] = {
+      val verticesTMP: RDD[(VertexId, TopicCounts)] =
+        edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
+          val random = new Random(partIndex + randomSeed)
+          partEdges.flatMap { edge =>
+            val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
+            val sum = gamma * edge.attr
+            Seq((edge.srcId, sum), (edge.dstId, sum))
+          }
         }
-      }
-    def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = {
-      val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] =
-        edgesWithGamma.map { case (edge, gamma: TopicCounts) =>
-          (sendToWhere(edge), (edge.attr, gamma))
-        }
-      verticesTMP.aggregateByKey(BDV.zeros[Double](k))(
-        (sum, t) => {
-          brzAxpy(t._1, t._2, sum)
-          sum
-        },
-        (sum0, sum1) => {
-          sum0 += sum1
-        }
-      )
+      verticesTMP.reduceByKey(_ + _)
     }
-    val docVertices = createVertices(_.srcId)
-    val termVertices = createVertices(_.dstId)
+
+    val docTermVertices = createVertices()
 
     // Partition such that edges are grouped by document
-    val graph = Graph(docVertices ++ termVertices, edges)
+    val graph = Graph(docTermVertices, edges)
       .partitionBy(PartitionStrategy.EdgePartition1D)
 
     new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
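
A detail worth noting in the new path: the token count `edge.attr` is folded into each contribution up front (`val sum = gamma * edge.attr`), so the old `aggregateByKey` seqOp, which accumulated with `brzAxpy(t._1, t._2, sum)`, collapses into plain vector addition under `reduceByKey`. Below is a standalone Breeze sketch of that equivalence; it assumes Breeze on the classpath and that `brzAxpy` in LDA.scala is an import alias for `breeze.linalg.axpy`, and it is an illustration rather than code from the commit.

```scala
import breeze.linalg.{axpy, normalize, DenseVector => BDV}

// Illustrative check that scaling gamma by the token count up front
// (new path) matches accumulating count * gamma via axpy (old path).
object GammaScalingSketch {
  def main(args: Array[String]): Unit = {
    val k = 4
    val random = new scala.util.Random(42L)
    // A random topic distribution for one token, as in the diff above.
    val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
    val count = 3.0 // stands in for edge.attr (a TokenCount)

    // New path: scale once, then combine contributions by vector addition.
    val scaled = gamma * count

    // Old path: accumulate count * gamma into a running sum via axpy.
    val sum = BDV.zeros[Double](k)
    axpy(count, gamma, sum)

    // Both paths yield the same per-vertex contribution.
    assert((scaled - sum).toArray.forall(v => math.abs(v) < 1e-12))
    println(scaled)
  }
}
```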
