mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala (24 additions & 3 deletions)
@@ -79,7 +79,24 @@ class LDA private (
*
* This is the parameter to a Dirichlet distribution.
*/
- def getDocConcentration: Vector = this.docConcentration
+ def getAsymmetricDocConcentration: Vector = this.docConcentration
+
+ /**
+  * Concentration parameter (commonly named "alpha") for the prior placed on documents'
+  * distributions over topics ("theta").
+  *
+  * This method assumes the Dirichlet distribution is symmetric and can be described by a single
+  * [[Double]] parameter. It should fail if docConcentration is asymmetric.
+  */
+ def getDocConcentration: Double = {
+   val parameter = docConcentration(0)
+   if (docConcentration.size == 1) {
+     parameter
+   } else {
+     require(docConcentration.toArray.forall(_ == parameter))
+     parameter
+   }
+ }
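A quick usage sketch of the accessor split introduced above (illustrative only, not part of the patch; the values are hypothetical):

    val lda = new LDA().setK(3)
    lda.setDocConcentration(0.5)                            // symmetric prior
    lda.getDocConcentration                                 // 0.5
    lda.setDocConcentration(Vectors.dense(0.1, 0.2, 0.7))  // asymmetric prior
    lda.getAsymmetricDocConcentration                       // [0.1, 0.2, 0.7]
    lda.getDocConcentration                                 // throws IllegalArgumentException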

/**
* Concentration parameter (commonly named "alpha") for the prior placed on documents'
@@ -106,18 +123,22 @@ class LDA private (
* [[https://github.com/Blei-Lab/onlineldavb]].
*/
def setDocConcentration(docConcentration: Vector): this.type = {
+ require(docConcentration.size > 0, "docConcentration must have > 0 elements")
this.docConcentration = docConcentration
this
[Review comment from Member] As long as you're editing here, can you add a check for docConcentration.length > 0?

[Reply from Contributor Author] OK
}

- /** Replicates Double to create a symmetric prior */
+ /** Replicates a [[Double]] docConcentration to create a symmetric prior. */
def setDocConcentration(docConcentration: Double): this.type = {
this.docConcentration = Vectors.dense(docConcentration)
this
}

+ /** Alias for [[getAsymmetricDocConcentration]] */
+ def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration

/** Alias for [[getDocConcentration]] */
- def getAlpha: Vector = getDocConcentration
+ def getAlpha: Double = getDocConcentration

/** Alias for [[setDocConcentration()]] */
def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha)
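Taken together, the new guard, the [[Double]] overload, and the aliases behave as follows (a sketch, not part of the patch; it assumes a setAlpha(alpha: Double) alias also exists, as the test suite below suggests):

    val lda = new LDA().setK(2)
    lda.setDocConcentration(Vectors.dense(Array.empty[Double]))  // IllegalArgumentException: > 0 elements
    lda.setAlpha(0.3)           // stored internally as the length-1 vector [0.3]
    lda.getAlpha                // 0.3   (Double, via getDocConcentration)
    lda.getAsymmetricAlpha      // [0.3] (Vector, via getAsymmetricDocConcentration)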
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -27,7 +27,6 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
- import org.apache.spark.broadcast.Broadcast
import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
@@ -190,7 +189,8 @@ class LocalLDAModel private[clustering] (
val topics: Matrix,
override val docConcentration: Vector,
override val topicConcentration: Double,
- override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable {
+ override protected[clustering] val gammaShape: Double = 100)
+ extends LDAModel with Serializable {

override def k: Int = topics.numCols
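With gammaShape now defaulting to 100, a LocalLDAModel can be built without threading that value through. A minimal sketch (it assumes code living in the org.apache.spark.mllib.clustering package, since the constructor is private[clustering]; the matrix values are hypothetical):

    val topics = Matrices.dense(4, 2, Array.fill(8)(0.25))  // vocabSize = 4, k = 2
    val model = new LocalLDAModel(topics, Vectors.dense(0.5, 0.5), 1.0)
    model.k  // 2; gammaShape was defaulted to 100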

@@ -455,8 +455,9 @@ class DistributedLDAModel private[clustering] (
val vocabSize: Int,
override val docConcentration: Vector,
override val topicConcentration: Double,
- override protected[clustering] val gammaShape: Double,
- private[spark] val iterationTimes: Array[Double]) extends LDAModel {
+ private[spark] val iterationTimes: Array[Double],
+ override protected[clustering] val gammaShape: Double = 100)
+ extends LDAModel {

import LDA._

@@ -756,7 +757,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)

new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
- docConcentration, topicConcentration, gammaShape, iterationTimes)
+ docConcentration, topicConcentration, iterationTimes, gammaShape)
}

}
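The parameter reorder above is what makes the default usable: a trailing default argument is easiest to omit, so iterationTimes moves ahead of gammaShape. Hypothetical call sites (argument names invented for illustration):

    new DistributedLDAModel(graph, totals, k, vocabSize, alpha, eta, times)        // gammaShape = 100
    new DistributedLDAModel(graph, totals, k, vocabSize, alpha, eta, times, 50.0)  // explicit override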
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -95,10 +95,8 @@ final class EMLDAOptimizer extends LDAOptimizer {
* Compute bipartite term/doc graph.
*/
override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
- val docConcentration = lda.getDocConcentration(0)
- require({
-   lda.getDocConcentration.toArray.forall(_ == docConcentration)
- }, "EMLDAOptimizer currently only supports symmetric document-topic priors")
+ // EMLDAOptimizer currently only supports symmetric document-topic priors
+ val docConcentration = lda.getDocConcentration

val topicConcentration = lda.getTopicConcentration
val k = lda.getK
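The symmetry check is relocated rather than dropped: getDocConcentration itself now requires all entries to be equal, so an asymmetric prior still fails for the EM optimizer, just one call deeper. A sketch of the failure mode (corpus is a hypothetical RDD[(Long, Vector)]):

    val lda = new LDA().setK(2).setOptimizer("em")
    lda.setDocConcentration(Vectors.dense(0.1, 0.9))
    lda.run(corpus)  // IllegalArgumentException thrown by getDocConcentration in initialize()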
@@ -209,11 +207,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
this.graphCheckpointer.deleteAllCheckpoints()
- // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal
- // conversion
+ // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in
+ // LDAModel.toLocal conversion
new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
- 100, iterationTimes)
+ iterationTimes)
}
}
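For background on the magic number (drawn from how OnlineLDAOptimizer uses it, not stated in this diff): gammaShape is the shape of the Gamma(shape, 1/shape) distribution used to initialize the variational parameter gamma, giving samples with mean 1.0 and variance 1/shape:

    import breeze.stats.distributions.Gamma
    val gammaShape = 100.0
    val g = Gamma(gammaShape, 1.0 / gammaShape)  // mean = shape * scale = 1.0
    g.sample()                                    // typically within ~0.1 of 1.0 (variance 0.01)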

@@ -378,18 +376,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
this.k = lda.getK
this.corpusSize = docs.count()
this.vocabSize = docs.first()._2.size
- this.alpha = if (lda.getDocConcentration.size == 1) {
-   if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
+ this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) {
+   if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
else {
-   require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha")
-   Vectors.dense(Array.fill(k)(lda.getDocConcentration(0)))
+   require(lda.getAsymmetricDocConcentration(0) >= 0,
+     s"all entries in alpha must be >=0, got: $alpha")
+   Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0)))
}
} else {
- require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha")
- lda.getDocConcentration.foreachActive { case (_, x) =>
+ require(lda.getAsymmetricDocConcentration.size == k,
+   s"alpha must have length k, got: $alpha")
+ lda.getAsymmetricDocConcentration.foreachActive { case (_, x) =>
require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha")
}
- lda.getDocConcentration
+ lda.getAsymmetricDocConcentration
}
this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
this.randomGenerator = new Random(lda.getSeed)
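In summary, the alpha resolution above maps user input to the concrete document-topic prior like this (hypothetical inputs, with k = 4):

    // setDocConcentration(-1)                         => alpha = [0.25, 0.25, 0.25, 0.25] (default 1/k)
    // setDocConcentration(0.3)                        => alpha = [0.3, 0.3, 0.3, 0.3]     (replicated)
    // setDocConcentration(Vectors.dense(1, 2, 3, 4))  => alpha = [1.0, 2.0, 3.0, 4.0]     (kept as-is)
    // setDocConcentration(Vectors.dense(1, 2))        => IllegalArgumentException (length != k)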
mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -160,8 +160,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {

test("setter alias") {
val lda = new LDA().setAlpha(2.0).setBeta(3.0)
- assert(lda.getAlpha.toArray.forall(_ === 2.0))
- assert(lda.getDocConcentration.toArray.forall(_ === 2.0))
+ assert(lda.getAsymmetricAlpha.toArray.forall(_ === 2.0))
+ assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0))
assert(lda.getBeta === 3.0)
assert(lda.getTopicConcentration === 3.0)
}
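A complementary pair of assertions one might add here to pin down the new symmetric accessors (a sketch, not part of this patch):

    assert(lda.getAlpha === 2.0)
    assert(lda.getDocConcentration === 2.0)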