Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 94 additions & 86 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -314,31 +314,31 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* Model fitted by [[LDA]].
*
* @param vocabSize Vocabulary size (number of terms or words in the vocabulary)
* @param oldLocalModel Underlying spark.mllib model.
* If this model was produced by Online LDA, then this is the
* only model representation.
* If this model was produced by EM, then this local
* representation may be built lazily.
* @param sqlContext Used to construct local DataFrames for returning query results
*/
@Since("1.6.0")
@Experimental
class LDAModel private[ml] (
sealed abstract class LDAModel private[ml] (
@Since("1.6.0") override val uid: String,
@Since("1.6.0") val vocabSize: Int,
@Since("1.6.0") protected var oldLocalModel: Option[OldLocalLDAModel],
@Since("1.6.0") @transient protected val sqlContext: SQLContext)
extends Model[LDAModel] with LDAParams with Logging {

/** Returns underlying spark.mllib model */
// NOTE to developers:
// This abstraction should contain all important functionality for basic LDA usage.
// Specializations of this class can contain expert-only functionality.

/**
* Underlying spark.mllib model.
* If this model was produced by Online LDA, then this is the only model representation.
* If this model was produced by EM, then this local representation may be built lazily.
*/
@Since("1.6.0")
protected def getModel: OldLDAModel = oldLocalModel match {
case Some(m) => m
case None =>
// Should never happen.
throw new RuntimeException("LDAModel required local model format," +
" but the underlying model is missing.")
}
protected def oldLocalModel: OldLocalLDAModel

/** Returns underlying spark.mllib model, which may be local or distributed */
@Since("1.6.0")
protected def getModel: OldLDAModel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we warn that this (and any method that uses this) will trigger a collect in DistributedLDAModel (if so, does the doc belong here or in DistributedLDAModel), or can we expect users of DistributedLDAModel to understand that already?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not trigger a collect. It's oldLocalModel which can trigger a collect.

I'll add some more warnings in places where it can happen in the public API.


/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
Expand All @@ -352,16 +352,17 @@ class LDAModel private[ml] (
@Since("1.6.0")
def setSeed(value: Long): this.type = set(seed, value)

@Since("1.6.0")
override def copy(extra: ParamMap): LDAModel = {
val copied = new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
copyValues(copied, extra).setParent(parent)
}

/**
* Transforms the input dataset.
*
* WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*/
@Since("1.6.0")
override def transform(dataset: DataFrame): DataFrame = {
if ($(topicDistributionCol).nonEmpty) {
val t = udf(oldLocalModel.get.getTopicDistributionMethod(sqlContext.sparkContext))
val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
Expand All @@ -388,56 +389,50 @@ class LDAModel private[ml] (
* This is a matrix of size vocabSize x k, where each column is a topic.
* No guarantees are given about the ordering of the topics.
*
* WARNING: If this model is actually a [[DistributedLDAModel]] instance from EM,
* then this method could involve collecting a large amount of data to the driver
* (on the order of vocabSize x k).
* WARNING: If this model is actually a [[DistributedLDAModel]] instance produced by
* the Expectation-Maximization ("em") [[optimizer]], then this method could involve
* collecting a large amount of data to the driver (on the order of vocabSize x k).
*/
@Since("1.6.0")
def topicsMatrix: Matrix = getModel.topicsMatrix
def topicsMatrix: Matrix = oldLocalModel.topicsMatrix

/** Indicates whether this instance is of type [[DistributedLDAModel]] */
@Since("1.6.0")
def isDistributed: Boolean = false
def isDistributed: Boolean

/**
* Calculates a lower bound on the log likelihood of the entire corpus.
*
* See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
*
* WARNING: If this model was learned via a [[DistributedLDAModel]], this involves collecting
* a large [[topicsMatrix]] to the driver. This implementation may be changed in the
* future.
* WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*
* @param dataset test corpus to use for calculating log likelihood
* @return variational lower bound on the log likelihood of the entire corpus
*/
@Since("1.6.0")
def logLikelihood(dataset: DataFrame): Double = oldLocalModel match {
case Some(m) =>
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
m.logLikelihood(oldDataset)
case None =>
// Should never happen.
throw new RuntimeException("LocalLDAModel.logLikelihood was called," +
" but the underlying model is missing.")
def logLikelihood(dataset: DataFrame): Double = {
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
oldLocalModel.logLikelihood(oldDataset)
}

/**
* Calculate an upper bound on perplexity. (Lower is better.)
* See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
*
* WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*
* @param dataset test corpus to use for calculating perplexity
* @return Variational upper bound on log perplexity per token.
*/
@Since("1.6.0")
def logPerplexity(dataset: DataFrame): Double = oldLocalModel match {
case Some(m) =>
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
m.logPerplexity(oldDataset)
case None =>
// Should never happen.
throw new RuntimeException("LocalLDAModel.logPerplexity was called," +
" but the underlying model is missing.")
def logPerplexity(dataset: DataFrame): Double = {
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
oldLocalModel.logPerplexity(oldDataset)
}

/**
Expand Down Expand Up @@ -468,70 +463,83 @@ class LDAModel private[ml] (
/**
* :: Experimental ::
*
* Distributed model fitted by [[LDA]] using Expectation-Maximization (EM).
* Local (non-distributed) model fitted by [[LDA]].
*
* This model stores the inferred topics only; it does not store info about the training dataset.
*/
@Since("1.6.0")
@Experimental
class LocalLDAModel private[ml] (
uid: String,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We annotated the constructor args to the private constructor for LDAModel but not here; why?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean the Since annotations? It's because those constructor args are also vals which will appear in the public API.

vocabSize: Int,
@Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does a protected val in a private constructor need a Since annotation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll have to clarify with @mengxr
It's technically a public API since the class is not final.

sqlContext: SQLContext)
extends LDAModel(uid, vocabSize, sqlContext) {

@Since("1.6.0")
override def copy(extra: ParamMap): LocalLDAModel = {
val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel]
}

override protected def getModel: OldLDAModel = oldLocalModel

@Since("1.6.0")
override def isDistributed: Boolean = false
}


/**
* :: Experimental ::
*
* Distributed model fitted by [[LDA]].
* This type of model is currently only produced by Expectation-Maximization (EM).
*
* This model stores the inferred topics, the full training dataset, and the topic distribution
* for each training document.
*
* @param oldLocalModelOption Used to implement [[oldLocalModel]] as a lazy val while keeping
* [[copy()]] cheap.
*/
@Since("1.6.0")
@Experimental
class DistributedLDAModel private[ml] (
uid: String,
vocabSize: Int,
private val oldDistributedModel: OldDistributedLDAModel,
sqlContext: SQLContext)
extends LDAModel(uid, vocabSize, None, sqlContext) {
sqlContext: SQLContext,
private var oldLocalModelOption: Option[OldLocalLDAModel])
extends LDAModel(uid, vocabSize, sqlContext) {

override protected def oldLocalModel: OldLocalLDAModel = {
if (oldLocalModelOption.isEmpty) {
oldLocalModelOption = Some(oldDistributedModel.toLocal)
}
oldLocalModelOption.get
}

override protected def getModel: OldLDAModel = oldDistributedModel

/**
* Convert this distributed model to a local representation. This discards info about the
* training dataset.
*
* WARNING: This involves collecting a large [[topicsMatrix]] to the driver.
*/
@Since("1.6.0")
def toLocal: LDAModel = {
if (oldLocalModel.isEmpty) {
oldLocalModel = Some(oldDistributedModel.toLocal)
}
new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
}

@Since("1.6.0")
override protected def getModel: OldLDAModel = oldDistributedModel
def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)

@Since("1.6.0")
override def copy(extra: ParamMap): DistributedLDAModel = {
val copied = new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext)
if (oldLocalModel.nonEmpty) copied.oldLocalModel = oldLocalModel
val copied =
new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption)
copyValues(copied, extra).setParent(parent)
copied
}

@Since("1.6.0")
override def topicsMatrix: Matrix = {
if (oldLocalModel.isEmpty) {
oldLocalModel = Some(oldDistributedModel.toLocal)
}
super.topicsMatrix
}

@Since("1.6.0")
override def isDistributed: Boolean = true

@Since("1.6.0")
override def logLikelihood(dataset: DataFrame): Double = {
if (oldLocalModel.isEmpty) {
oldLocalModel = Some(oldDistributedModel.toLocal)
}
super.logLikelihood(dataset)
}

@Since("1.6.0")
override def logPerplexity(dataset: DataFrame): Double = {
if (oldLocalModel.isEmpty) {
oldLocalModel = Some(oldDistributedModel.toLocal)
}
super.logPerplexity(dataset)
}

/**
* Log likelihood of the observed tokens in the training set,
* given the current parameter estimates:
Expand Down Expand Up @@ -673,9 +681,9 @@ class LDA @Since("1.6.0") (
val oldModel = oldLDA.run(oldData)
val newModel = oldModel match {
case m: OldLocalLDAModel =>
new LDAModel(uid, m.vocabSize, Some(m), dataset.sqlContext)
new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
case m: OldDistributedLDAModel =>
new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None)
}
copyValues(newModel).setParent(this)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {

MLTestingUtils.checkCopy(model)

assert(!model.isInstanceOf[DistributedLDAModel])
assert(model.isInstanceOf[LocalLDAModel])
assert(model.vocabSize === vocabSize)
assert(model.estimatedDocConcentration.size === k)
assert(model.topicsMatrix.numRows === vocabSize)
Expand Down Expand Up @@ -210,7 +210,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.isDistributed)

val localModel = model.toLocal
assert(!localModel.isInstanceOf[DistributedLDAModel])
assert(localModel.isInstanceOf[LocalLDAModel])

// training logLikelihood, logPrior
val ll = model.trainingLogLikelihood
Expand Down