20 changes: 19 additions & 1 deletion R/pkg/R/mllib_clustering.R
@@ -388,6 +388,13 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"),
#' \item{\code{topics}}{top 10 terms and their weights of all topics}
#' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file
#' used as training set}
+#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the training set,
+#'         given the current parameter estimates:
+#'         log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters).
+#'         It is only available for the distributed LDA model (i.e., optimizer = "em")}
+#' \item{\code{logPrior}}{Log probability of the current parameter estimate:
+#'         log P(topics, topic distributions for docs | Dirichlet hyperparameters).
+#'         It is only available for the distributed LDA model (i.e., optimizer = "em")}
#' @rdname spark.lda
#' @aliases summary,LDAModel-method
#' @export
@@ -404,11 +411,22 @@ setMethod("summary", signature(object = "LDAModel"),
vocabSize <- callJMethod(jobj, "vocabSize")
topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic))
vocabulary <- callJMethod(jobj, "vocabulary")
+trainingLogLikelihood <- if (isDistributed) {
+  callJMethod(jobj, "trainingLogLikelihood")
+} else {
+  NA
+}
+logPrior <- if (isDistributed) {
+  callJMethod(jobj, "logPrior")
+} else {
+  NA
+}
list(docConcentration = unlist(docConcentration),
topicConcentration = topicConcentration,
logLikelihood = logLikelihood, logPerplexity = logPerplexity,
isDistributed = isDistributed, vocabSize = vocabSize,
-topics = topics, vocabulary = unlist(vocabulary))
+topics = topics, vocabulary = unlist(vocabulary),
+trainingLogLikelihood = trainingLogLikelihood, logPrior = logPrior)
})

# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda}
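Taken together, the doc and summary() changes expose the two EM-only diagnostics in R. A minimal usage sketch of the resulting behavior, assuming a running SparkR session; the toy corpus is illustrative and not taken from the PR:

library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(text = c("1 2 2", "1 3 3 4", "2 4 4 5")))

# optimizer = "em" trains a distributed model, so both fields are populated.
model <- spark.lda(df, features = "text", k = 2, optimizer = "em")
stats <- summary(model)
stats$trainingLogLikelihood  # finite and <= 0
stats$logPrior               # finite and <= 0

# optimizer = "online" trains a local model; both fields fall back to NA.
model2 <- spark.lda(df, features = "text", k = 2, optimizer = "online")
is.na(summary(model2)$trainingLogLikelihood)  # TRUE
is.na(summary(model2)$logPrior)               # TRUE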
16 changes: 14 additions & 2 deletions R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -146,12 +146,16 @@ test_that("spark.lda with libsvm", {
topics <- stats$topicTopTerms
weights <- stats$topicTopTermsWeights
vocabulary <- stats$vocabulary
+trainingLogLikelihood <- stats$trainingLogLikelihood
+logPrior <- stats$logPrior

-expect_false(isDistributed)
+expect_true(isDistributed)
expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
expect_equal(vocabSize, 11)
expect_true(is.null(vocabulary))
+expect_true(trainingLogLikelihood <= 0 & !is.na(trainingLogLikelihood))
+expect_true(logPrior <= 0 & !is.na(logPrior))

# Test model save/load
modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
@@ -161,11 +165,13 @@
model2 <- read.ml(modelPath)
stats2 <- summary(model2)

-expect_false(stats2$isDistributed)
+expect_true(stats2$isDistributed)
expect_equal(logLikelihood, stats2$logLikelihood)
expect_equal(logPerplexity, stats2$logPerplexity)
expect_equal(vocabSize, stats2$vocabSize)
expect_equal(vocabulary, stats2$vocabulary)
+expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood)
+expect_equal(logPrior, stats2$logPrior)

unlink(modelPath)
})
@@ -182,12 +188,16 @@ test_that("spark.lda with text input", {
topics <- stats$topicTopTerms
weights <- stats$topicTopTermsWeights
vocabulary <- stats$vocabulary
+trainingLogLikelihood <- stats$trainingLogLikelihood
+logPrior <- stats$logPrior

expect_false(isDistributed)
expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
expect_equal(vocabSize, 10)
expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
+expect_true(is.na(trainingLogLikelihood))
+expect_true(is.na(logPrior))

# Test model save/load
modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
@@ -202,6 +212,8 @@
expect_equal(logPerplexity, stats2$logPerplexity)
expect_equal(vocabSize, stats2$vocabSize)
expect_true(all.equal(vocabulary, stats2$vocabulary))
+expect_true(is.na(stats2$trainingLogLikelihood))
+expect_true(is.na(stats2$logPrior))

unlink(modelPath)
})
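Because both diagnostics are NA unless the model was trained with optimizer = "em", downstream code should guard on them before use; a small illustrative helper (the function name is ours, not part of the PR):

# Print the EM-only diagnostics when present; `stats` is the list returned by
# summary() on an LDAModel.
reportLDADiagnostics <- function(stats) {
  if (stats$isDistributed && !is.na(stats$trainingLogLikelihood)) {
    cat("training log-likelihood:", stats$trainingLogLikelihood, "\n")
    cat("log prior:", stats$logPrior, "\n")
  } else {
    cat("EM-only diagnostics unavailable for this optimizer\n")
  }
}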
1 change: 0 additions & 1 deletion R/pkg/inst/tests/testthat/test_mllib_tree.R
@@ -126,7 +126,6 @@ test_that("spark.randomForest", {
63.53160, 64.05470, 65.12710, 64.30450,
66.70910, 67.86125, 68.08700, 67.21865,
68.89275, 69.53180, 69.39640, 69.68250),
-
tolerance = 1e-4)
stats <- summary(model)
expect_equal(stats$numTrees, 20)
10 changes: 9 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
@@ -26,7 +26,7 @@ import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkException
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
-import org.apache.spark.ml.clustering.{LDA, LDAModel}
+import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LDAModel}
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.ParamPair
@@ -45,6 +45,13 @@ private[r] class LDAWrapper private (
import LDAWrapper._

private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel]

+// The following variables are called by the R-side code only when the LDA model is distributed
+lazy private val distributedModel =
+  pipeline.stages.last.asInstanceOf[DistributedLDAModel]
+lazy val trainingLogLikelihood: Double = distributedModel.trainingLogLikelihood
+lazy val logPrior: Double = distributedModel.logPrior
+
private val preprocessor: PipelineModel =
new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1))

@@ -122,6 +129,7 @@ private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
.setK(k)
.setMaxIter(maxIter)
.setSubsamplingRate(subsamplingRate)
+.setOptimizer(optimizer)

val featureSchema = data.schema(features)
val stages = featureSchema.dataType match {
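The lazy modifiers are what make the wrapper safe for both optimizers: LDAWrapper is constructed for local models too, and an eager asInstanceOf[DistributedLDAModel] would throw a ClassCastException at construction time whenever optimizer = "online". A hedged R-side sketch of the behavior this protects, reusing the toy df from the earlier sketch (direct callJMethod use is SparkR-internal and shown for illustration only):

model <- spark.lda(df, features = "text", k = 2, optimizer = "online")
summary(model)$trainingLogLikelihood  # NA; the R code never forces the lazy val
# Forcing it by hand would surface the ClassCastException from the JVM:
# SparkR:::callJMethod(model@jobj, "trainingLogLikelihood")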