-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-11712] [ML] Make spark.ml LDAModel be abstract #9678
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2545c89
3c97e41
9f39384
29d08a8
b3e9341
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -314,31 +314,31 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM | |
| * Model fitted by [[LDA]]. | ||
| * | ||
| * @param vocabSize Vocabulary size (number of terms or words in the vocabulary) | ||
| * @param oldLocalModel Underlying spark.mllib model. | ||
| * If this model was produced by Online LDA, then this is the | ||
| * only model representation. | ||
| * If this model was produced by EM, then this local | ||
| * representation may be built lazily. | ||
| * @param sqlContext Used to construct local DataFrames for returning query results | ||
| */ | ||
| @Since("1.6.0") | ||
| @Experimental | ||
| class LDAModel private[ml] ( | ||
| sealed abstract class LDAModel private[ml] ( | ||
| @Since("1.6.0") override val uid: String, | ||
| @Since("1.6.0") val vocabSize: Int, | ||
| @Since("1.6.0") protected var oldLocalModel: Option[OldLocalLDAModel], | ||
| @Since("1.6.0") @transient protected val sqlContext: SQLContext) | ||
| extends Model[LDAModel] with LDAParams with Logging { | ||
|
|
||
| /** Returns underlying spark.mllib model */ | ||
| // NOTE to developers: | ||
| // This abstraction should contain all important functionality for basic LDA usage. | ||
| // Specializations of this class can contain expert-only functionality. | ||
|
|
||
| /** | ||
| * Underlying spark.mllib model. | ||
| * If this model was produced by Online LDA, then this is the only model representation. | ||
| * If this model was produced by EM, then this local representation may be built lazily. | ||
| */ | ||
| @Since("1.6.0") | ||
| protected def getModel: OldLDAModel = oldLocalModel match { | ||
| case Some(m) => m | ||
| case None => | ||
| // Should never happen. | ||
| throw new RuntimeException("LDAModel required local model format," + | ||
| " but the underlying model is missing.") | ||
| } | ||
| protected def oldLocalModel: OldLocalLDAModel | ||
|
|
||
| /** Returns underlying spark.mllib model, which may be local or distributed */ | ||
| @Since("1.6.0") | ||
| protected def getModel: OldLDAModel | ||
|
|
||
| /** | ||
| * The features for LDA should be a [[Vector]] representing the word counts in a document. | ||
|
|
@@ -352,16 +352,17 @@ class LDAModel private[ml] ( | |
| @Since("1.6.0") | ||
| def setSeed(value: Long): this.type = set(seed, value) | ||
|
|
||
| @Since("1.6.0") | ||
| override def copy(extra: ParamMap): LDAModel = { | ||
| val copied = new LDAModel(uid, vocabSize, oldLocalModel, sqlContext) | ||
| copyValues(copied, extra).setParent(parent) | ||
| } | ||
|
|
||
| /** | ||
| * Transforms the input dataset. | ||
| * | ||
| * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] | ||
| * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. | ||
| * This implementation may be changed in the future. | ||
| */ | ||
| @Since("1.6.0") | ||
| override def transform(dataset: DataFrame): DataFrame = { | ||
| if ($(topicDistributionCol).nonEmpty) { | ||
| val t = udf(oldLocalModel.get.getTopicDistributionMethod(sqlContext.sparkContext)) | ||
| val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext)) | ||
| dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))) | ||
| } else { | ||
| logWarning("LDAModel.transform was called without any output columns. Set an output column" + | ||
|
|
@@ -388,56 +389,50 @@ class LDAModel private[ml] ( | |
| * This is a matrix of size vocabSize x k, where each column is a topic. | ||
| * No guarantees are given about the ordering of the topics. | ||
| * | ||
| * WARNING: If this model is actually a [[DistributedLDAModel]] instance from EM, | ||
| * then this method could involve collecting a large amount of data to the driver | ||
| * (on the order of vocabSize x k). | ||
| * WARNING: If this model is actually a [[DistributedLDAModel]] instance produced by | ||
| * the Expectation-Maximization ("em") [[optimizer]], then this method could involve | ||
| * collecting a large amount of data to the driver (on the order of vocabSize x k). | ||
| */ | ||
| @Since("1.6.0") | ||
| def topicsMatrix: Matrix = getModel.topicsMatrix | ||
| def topicsMatrix: Matrix = oldLocalModel.topicsMatrix | ||
|
|
||
| /** Indicates whether this instance is of type [[DistributedLDAModel]] */ | ||
| @Since("1.6.0") | ||
| def isDistributed: Boolean = false | ||
| def isDistributed: Boolean | ||
|
|
||
| /** | ||
| * Calculates a lower bound on the log likelihood of the entire corpus. | ||
| * | ||
| * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). | ||
| * | ||
| * WARNING: If this model was learned via a [[DistributedLDAModel]], this involves collecting | ||
| * a large [[topicsMatrix]] to the driver. This implementation may be changed in the | ||
| * future. | ||
| * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] | ||
| * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. | ||
| * This implementation may be changed in the future. | ||
| * | ||
| * @param dataset test corpus to use for calculating log likelihood | ||
| * @return variational lower bound on the log likelihood of the entire corpus | ||
| */ | ||
| @Since("1.6.0") | ||
| def logLikelihood(dataset: DataFrame): Double = oldLocalModel match { | ||
| case Some(m) => | ||
| val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) | ||
| m.logLikelihood(oldDataset) | ||
| case None => | ||
| // Should never happen. | ||
| throw new RuntimeException("LocalLDAModel.logLikelihood was called," + | ||
| " but the underlying model is missing.") | ||
| def logLikelihood(dataset: DataFrame): Double = { | ||
| val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) | ||
| oldLocalModel.logLikelihood(oldDataset) | ||
| } | ||
|
|
||
| /** | ||
| * Calculate an upper bound on perplexity. (Lower is better.) | ||
| * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). | ||
| * | ||
| * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] | ||
| * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. | ||
| * This implementation may be changed in the future. | ||
| * | ||
| * @param dataset test corpus to use for calculating perplexity | ||
| * @return Variational upper bound on log perplexity per token. | ||
| */ | ||
| @Since("1.6.0") | ||
| def logPerplexity(dataset: DataFrame): Double = oldLocalModel match { | ||
| case Some(m) => | ||
| val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) | ||
| m.logPerplexity(oldDataset) | ||
| case None => | ||
| // Should never happen. | ||
| throw new RuntimeException("LocalLDAModel.logPerplexity was called," + | ||
| " but the underlying model is missing.") | ||
| def logPerplexity(dataset: DataFrame): Double = { | ||
| val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) | ||
| oldLocalModel.logPerplexity(oldDataset) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -468,70 +463,83 @@ class LDAModel private[ml] ( | |
| /** | ||
| * :: Experimental :: | ||
| * | ||
| * Distributed model fitted by [[LDA]] using Expectation-Maximization (EM). | ||
| * Local (non-distributed) model fitted by [[LDA]]. | ||
| * | ||
| * This model stores the inferred topics only; it does not store info about the training dataset. | ||
| */ | ||
| @Since("1.6.0") | ||
| @Experimental | ||
| class LocalLDAModel private[ml] ( | ||
| uid: String, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We annotated the constructor args to the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do you mean the Since annotations? It's because those constructor args are also vals which will appear in the public API. |
||
| vocabSize: Int, | ||
| @Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does a
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'll have to clarify with @mengxr |
||
| sqlContext: SQLContext) | ||
| extends LDAModel(uid, vocabSize, sqlContext) { | ||
|
|
||
| @Since("1.6.0") | ||
| override def copy(extra: ParamMap): LocalLDAModel = { | ||
| val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext) | ||
| copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel] | ||
| } | ||
|
|
||
| override protected def getModel: OldLDAModel = oldLocalModel | ||
|
|
||
| @Since("1.6.0") | ||
| override def isDistributed: Boolean = false | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * | ||
| * Distributed model fitted by [[LDA]]. | ||
| * This type of model is currently only produced by Expectation-Maximization (EM). | ||
| * | ||
| * This model stores the inferred topics, the full training dataset, and the topic distribution | ||
| * for each training document. | ||
| * | ||
| * @param oldLocalModelOption Used to implement [[oldLocalModel]] as a lazy val, but keeping | ||
| * [[copy()]] cheap. | ||
| */ | ||
| @Since("1.6.0") | ||
| @Experimental | ||
| class DistributedLDAModel private[ml] ( | ||
| uid: String, | ||
| vocabSize: Int, | ||
| private val oldDistributedModel: OldDistributedLDAModel, | ||
| sqlContext: SQLContext) | ||
| extends LDAModel(uid, vocabSize, None, sqlContext) { | ||
| sqlContext: SQLContext, | ||
| private var oldLocalModelOption: Option[OldLocalLDAModel]) | ||
| extends LDAModel(uid, vocabSize, sqlContext) { | ||
|
|
||
| override protected def oldLocalModel: OldLocalLDAModel = { | ||
| if (oldLocalModelOption.isEmpty) { | ||
| oldLocalModelOption = Some(oldDistributedModel.toLocal) | ||
| } | ||
| oldLocalModelOption.get | ||
| } | ||
|
|
||
| override protected def getModel: OldLDAModel = oldDistributedModel | ||
|
|
||
| /** | ||
| * Convert this distributed model to a local representation. This discards info about the | ||
| * training dataset. | ||
| * | ||
| * WARNING: This involves collecting a large [[topicsMatrix]] to the driver. | ||
| */ | ||
| @Since("1.6.0") | ||
| def toLocal: LDAModel = { | ||
| if (oldLocalModel.isEmpty) { | ||
| oldLocalModel = Some(oldDistributedModel.toLocal) | ||
| } | ||
| new LDAModel(uid, vocabSize, oldLocalModel, sqlContext) | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override protected def getModel: OldLDAModel = oldDistributedModel | ||
| def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext) | ||
|
|
||
| @Since("1.6.0") | ||
| override def copy(extra: ParamMap): DistributedLDAModel = { | ||
| val copied = new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext) | ||
| if (oldLocalModel.nonEmpty) copied.oldLocalModel = oldLocalModel | ||
| val copied = | ||
| new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption) | ||
| copyValues(copied, extra).setParent(parent) | ||
| copied | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def topicsMatrix: Matrix = { | ||
| if (oldLocalModel.isEmpty) { | ||
| oldLocalModel = Some(oldDistributedModel.toLocal) | ||
| } | ||
| super.topicsMatrix | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def isDistributed: Boolean = true | ||
|
|
||
| @Since("1.6.0") | ||
| override def logLikelihood(dataset: DataFrame): Double = { | ||
| if (oldLocalModel.isEmpty) { | ||
| oldLocalModel = Some(oldDistributedModel.toLocal) | ||
| } | ||
| super.logLikelihood(dataset) | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def logPerplexity(dataset: DataFrame): Double = { | ||
| if (oldLocalModel.isEmpty) { | ||
| oldLocalModel = Some(oldDistributedModel.toLocal) | ||
| } | ||
| super.logPerplexity(dataset) | ||
| } | ||
|
|
||
| /** | ||
| * Log likelihood of the observed tokens in the training set, | ||
| * given the current parameter estimates: | ||
|
|
@@ -673,9 +681,9 @@ class LDA @Since("1.6.0") ( | |
| val oldModel = oldLDA.run(oldData) | ||
| val newModel = oldModel match { | ||
| case m: OldLocalLDAModel => | ||
| new LDAModel(uid, m.vocabSize, Some(m), dataset.sqlContext) | ||
| new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext) | ||
| case m: OldDistributedLDAModel => | ||
| new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext) | ||
| new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None) | ||
| } | ||
| copyValues(newModel).setParent(this) | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we warn that this (and any method that uses this) will trigger a
`collect` in `DistributedLDAModel` (if so, does the doc belong here or in `DistributedLDAModel`), or can we expect users of `DistributedLDAModel` to understand that already?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not trigger a collect. It's oldLocalModel which can trigger a collect.
I'll add some more warnings in places where it can happen in the public API.