diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala
index 63282eee6e65..dd02c3a82d02 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala
@@ -16,11 +16,13 @@
  */
 package org.apache.spark.mllib.api.python
 
-import scala.collection.JavaConverters
+import scala.collection.JavaConverters._
 
 import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.clustering.LDAModel
-import org.apache.spark.mllib.linalg.Matrix
+import org.apache.spark.mllib.linalg.{Matrix, Vector}
+import org.apache.spark.rdd.RDD
 
 /**
  * Wrapper around LDAModel to provide helper methods in Python
@@ -35,11 +37,26 @@ private[python] class LDAModelWrapper(model: LDAModel) {
 
   def describeTopics(maxTermsPerTopic: Int): Array[Byte] = {
     val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) =>
-      val jTerms = JavaConverters.seqAsJavaListConverter(terms).asJava
-      val jTermWeights = JavaConverters.seqAsJavaListConverter(termWeights).asJava
+      val jTerms = seqAsJavaListConverter(terms).asJava
+      val jTermWeights = seqAsJavaListConverter(termWeights).asJava
       Array[Any](jTerms, jTermWeights)
     }
-    SerDe.dumps(JavaConverters.seqAsJavaListConverter(topics).asJava)
+    SerDe.dumps(seqAsJavaListConverter(topics).asJava)
+  }
+
+  def topicDistributions(data: JavaRDD[java.util.List[Any]]): JavaRDD[Array[Any]] = {
+    // Each element from Python is a two-item list: a document id, then a term-count vector.
+    val documents = data.rdd.map(_.asScala.toArray).map { r =>
+      r(0) match {
+        case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
+        case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector])
+        case _ => throw new IllegalArgumentException("document id must be an Int or a Long")
+      }
+    }
+
+    val distributions = model.topicDistributions(documents)
+
+    SerDe.fromTuple2RDD(distributions.asInstanceOf[RDD[(Any, Any)]])
   }
 
   def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 90d8a558f10d..c7b760c3d19c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -176,7 +176,7 @@ abstract class LDAModel private[clustering] extends Saveable {
    * The returned RDD may be zipped with the given RDD, where each returned vector
    * is a multinomial distribution over topics.
    */
-  // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)]
+  def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)]
 }
@@ -341,7 +341,6 @@ class LocalLDAModel private[spark] (
    */
   @Since("1.3.0")
-  // TODO: declare in LDAModel and override once implemented in DistributedLDAModel
-  def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
+  override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
     // Double transpose because dirichletExpectation normalizes by row and we need to normalize
     // by topic (columns of lambda)
     val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t)
@@ -777,6 +777,10 @@ class DistributedLDAModel private[clustering] (
     JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
   }
 
+  override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = {
+    throw new NotImplementedError("Use toLocal to convert this model to a LocalLDAModel first.")
+  }
+
   /**
    * For each document, return the top k weighted topics for that document and their weights.
    * @return RDD of (doc ID, topic indices, topic weights)
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index c8c3c42774f2..ebdacaee01c7 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -972,6 +972,21 @@ def describeTopics(self, maxTermsPerTopic=None):
         topics = self.call("describeTopics", maxTermsPerTopic)
         return topics
 
+    def topicDistributions(self, documents):
+        """
+        Compute the estimated topic distribution for each document.
+        This is often called 'theta' in the literature.
+
+        :param documents:
+          RDD of (document id, term count vector) pairs.
+        :return:
+          RDD of (document id, topic distribution) pairs, where the
+          distribution is a vector of weights over the k topics.
+        """
+        if not isinstance(documents, RDD):
+            raise TypeError("documents should be an RDD, got %s" % type(documents))
+        return self.call("topicDistributions", documents)
+
     @classmethod
     @since('1.5.0')
     def load(cls, sc, path):
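
For reviewers who want to exercise the new Python API end to end, here is a minimal usage sketch (not part of the patch). The corpus, `k=3`, and the app name are illustrative assumptions; `LDA.train` and its `optimizer` parameter are the existing `pyspark.mllib` API. Because `topicDistributions` is only implemented by `LocalLDAModel`, the sketch trains with the online optimizer; on an EM-trained model the call would surface the `NotImplementedError` added above.

```python
from pyspark import SparkContext
from pyspark.mllib.clustering import LDA
from pyspark.mllib.linalg import Vectors

sc = SparkContext(appName="LDATopicDistributionsExample")  # hypothetical app name

# Each document is a (document id, term-count vector) pair; ids must be unique.
corpus = sc.parallelize([
    [0, Vectors.dense([1.0, 2.0, 6.0, 0.0, 2.0])],
    [1, Vectors.dense([1.0, 3.0, 0.0, 1.0, 3.0])],
    [2, Vectors.dense([0.0, 5.0, 2.0, 2.0, 1.0])],
])

# Train with the online optimizer so the underlying model is a LocalLDAModel,
# the only model type for which topicDistributions is implemented.
model = LDA.train(corpus, k=3, optimizer="online")

# New API from this patch: the per-document topic mixture ("theta").
for doc_id, theta in model.topicDistributions(corpus).collect():
    print(doc_id, theta)

sc.stop()
```

Each returned vector sums to one, so it can be read directly as the posterior topic proportions for that document id.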