LDAModelWrapper.scala (org.apache.spark.mllib.api.python)

@@ -16,11 +16,13 @@
  */
 package org.apache.spark.mllib.api.python
 
-import scala.collection.JavaConverters
+import scala.collection.JavaConverters._
 
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.clustering.LDAModel
-import org.apache.spark.mllib.linalg.Matrix
+import org.apache.spark.mllib.linalg.{Matrix, Vector}
+import org.apache.spark.rdd.RDD
 
 /**
  * Wrapper around LDAModel to provide helper methods in Python
@@ -35,11 +37,30 @@ private[python] class LDAModelWrapper(model: LDAModel) {

   def describeTopics(maxTermsPerTopic: Int): Array[Byte] = {
     val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) =>
-      val jTerms = JavaConverters.seqAsJavaListConverter(terms).asJava
-      val jTermWeights = JavaConverters.seqAsJavaListConverter(termWeights).asJava
+      val jTerms = seqAsJavaListConverter(terms).asJava
+      val jTermWeights = seqAsJavaListConverter(termWeights).asJava
       Array[Any](jTerms, jTermWeights)
     }
-    SerDe.dumps(JavaConverters.seqAsJavaListConverter(topics).asJava)
+    SerDe.dumps(seqAsJavaListConverter(topics).asJava)
   }
 
+  def topicDistributions(
+      data: JavaRDD[java.util.List[Any]]): JavaRDD[Array[Any]] = {
+
+    val documents = data.rdd.map(_.asScala.toArray).map { r =>
+      r(0) match {
+        case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
+        case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector])
+        case _ => throw new IllegalArgumentException("input values contain an invalid type.")
+      }
+    }
+
+    val distributions = model.topicDistributions(documents)
+
+    SerDe.fromTuple2RDD(distributions.map {
+      case (id, vector) => (id.toLong, vector.asInstanceOf[Vector])
+    }.asInstanceOf[RDD[(Any, Any)]])
+
+  }
+
   def save(sc: SparkContext, path: String): Unit = model.save(sc, path)

LDAModel.scala (org.apache.spark.mllib.clustering)

@@ -176,7 +176,7 @@ abstract class LDAModel private[clustering] extends Saveable {
    * The returned RDD may be zipped with the given RDD, where each returned vector
    * is a multinomial distribution over topics.
    */
-  // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)]
+  def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)]
 
 }

@@ -341,7 +341,7 @@ class LocalLDAModel private[spark] (
    */
   @Since("1.3.0")
   // TODO: declare in LDAModel and override once implemented in DistributedLDAModel
-  def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
+  override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
     // Double transpose because dirichletExpectation normalizes by row and we need to normalize
     // by topic (columns of lambda)
     val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t)
@@ -777,6 +777,10 @@ class DistributedLDAModel private[clustering] (
     JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
   }
 
+  override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
+    throw new NotImplementedError("Convert to LocalLDAModel or use online optimizer.")
+  }
+
   /**
    * For each document, return the top k weighted topics for that document and their weights.
    * @return RDD of (doc ID, topic indices, topic weights)

Review comment on the new topicDistributions override (Contributor):
Is this what we want here? It seems that having it defined on the parent when half of the children don't implement it might be confusing to some users.

Reply (Author):
@holdenk I'm keen to work on this. I definitely agree, but I'm not sure how else to approach this without implementing the logic for distributed LDA models.
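For readers following the discussion above, here is a minimal PySpark sketch (not part of the diff) of how the two optimizer paths would behave once this change is in: the online optimizer yields a LocalLDAModel on the JVM side, where topicDistributions works, while the default EM optimizer yields a DistributedLDAModel, where the new override throws NotImplementedError and surfaces in Python as a Py4J error. The toy corpus, k, and app name are made up for illustration; the only APIs assumed are LDA.train with its optimizer argument and the topicDistributions call added in this PR.

```python
from pyspark import SparkContext
from pyspark.mllib.clustering import LDA
from pyspark.mllib.linalg import Vectors

sc = SparkContext(appName="lda-topic-distributions-sketch")

# Toy corpus: rows of [document id, term-count vector] over a 4-word vocabulary.
corpus = sc.parallelize([
    [0, Vectors.dense([1.0, 2.0, 0.0, 0.0])],
    [1, Vectors.dense([0.0, 0.0, 3.0, 1.0])],
    [2, Vectors.dense([2.0, 1.0, 1.0, 0.0])],
])

# Online optimizer -> LocalLDAModel on the JVM side: topicDistributions works.
online_model = LDA.train(corpus, k=2, optimizer="online")
print(online_model.topicDistributions(corpus).collect())

# Default EM optimizer -> DistributedLDAModel: the override added in this PR
# throws NotImplementedError, which reaches Python as a Py4J error.
em_model = LDA.train(corpus, k=2, optimizer="em")
try:
    em_model.topicDistributions(corpus).collect()
except Exception as e:
    print("topicDistributions on a distributed model failed: %s" % e)
```

The conversion the error message refers to is DistributedLDAModel.toLocal on the Scala side; as far as this diff goes, it is not exposed through the Python wrapper.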
14 changes: 14 additions & 0 deletions python/pyspark/mllib/clustering.py

@@ -972,6 +972,20 @@ def describeTopics(self, maxTermsPerTopic=None):
         topics = self.call("describeTopics", maxTermsPerTopic)
         return topics
 
+    def topicDistributions(self, documents):
+        """
+        Compute the estimated topic distribution for each document.
+        This is often called 'theta' in the literature.
+        :param documents:
+          RDD of document id and features.
+        :return:
+          RDD where each row is a tuple of document id and an array giving
+          the estimated topic distribution over the k topics.
+        """
+        if not isinstance(documents, RDD):
+            raise TypeError("documents should be an RDD, got type %s" % type(documents))
+        return self.call("topicDistributions", documents)
+
     @classmethod
     @since('1.5.0')
     def load(cls, sc, path):
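To complement the docstring above, a short, self-contained sketch of how the returned RDD might be consumed on the Python side: key both the corpus and the result by document id and join them, so each document's features sit next to its estimated theta. The data, k, and the assumption that each returned row unpacks as an (id, topic-distribution vector) pair are illustrative, not guaranteed by the docstring.

```python
from pyspark import SparkContext
from pyspark.mllib.clustering import LDA
from pyspark.mllib.linalg import Vectors

sc = SparkContext(appName="lda-theta-example")

corpus = sc.parallelize([
    [0, Vectors.dense([3.0, 0.0, 1.0])],
    [1, Vectors.dense([0.0, 2.0, 2.0])],
])

# Online optimizer so the underlying JVM model is a LocalLDAModel.
model = LDA.train(corpus, k=2, optimizer="online")

# Key both RDDs by document id so theta can be joined back to the features.
theta = model.topicDistributions(corpus).map(lambda row: (row[0], row[1]))
docs = corpus.map(lambda row: (row[0], row[1]))

for doc_id, (features, topic_dist) in docs.join(theta).collect():
    # Each topic distribution has k entries and sums to (approximately) 1.
    print(doc_id, features, topic_dist)
```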