diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index fe6a37fd6dc3..4fc5a81b985a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -18,7 +18,10 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
@@ -457,6 +460,31 @@ sealed abstract class LDAModel private[ml] (
   def describeTopics(): DataFrame = describeTopics(10)
 }
 
+object LDAModel extends MLReadable[LDAModel] {
+
+  private class LDAModelReader extends MLReader[LDAModel] {
+    override def load(path: String): LDAModel = {
+      val metadataPath = new Path(path, "metadata").toString
+      val metadata = parse(sc.textFile(metadataPath, 1).first())
+      implicit val format = DefaultFormats
+      val className = (metadata \ "class").extract[String]
+      className match {
+        case c if className == classOf[LocalLDAModel].getName =>
+          LocalLDAModel.load(path)
+        case c if className == classOf[DistributedLDAModel].getName =>
+          DistributedLDAModel.load(path)
+        case _ => throw new SparkException(s"$className in $path is not a LDAModel")
+      }
+    }
+  }
+
+  @Since("2.0.0")
+  override def read: MLReader[LDAModel] = new LDAModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): LDAModel = super.load(path)
+}
+
 
 /**
  * :: Experimental ::
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index dd3f4c6e5391..4805995c1ffa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.ml.clustering
 
+import java.io.File
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.util.Utils
 
 object LDASuite {
 
@@ -261,4 +264,22 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings ++
       Map("optimizer" -> "em"), checkModelData)
   }
+
+  test("load LDAModel") {
+    val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2)
+    val distributedModel = lda.fit(dataset)
+    val localModel = lda.setOptimizer("online").fit(dataset)
+
+    val tempDir1 = Utils.createTempDir()
+    val distributedPath = new File(tempDir1, "distributed").getPath
+    val localPath = new File(tempDir1, "local").getPath
+    try {
+      distributedModel.save(distributedPath)
+      localModel.save(localPath)
+      assert(LDAModel.load(distributedPath).isInstanceOf[DistributedLDAModel])
+      assert(LDAModel.load(localPath).isInstanceOf[LocalLDAModel])
+    } finally {
+      Utils.deleteRecursively(tempDir1)
+    }
+  }
 }
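
For context (not part of the patch): a minimal usage sketch of the generic loader introduced above. The `dataset` DataFrame, parameter values, and save path are hypothetical.

    import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LDAModel, LocalLDAModel}

    val lda = new LDA().setK(10).setMaxIter(5)   // hypothetical settings
    val model = lda.fit(dataset)                 // `dataset` is assumed to exist
    model.save("/tmp/lda-model")                 // hypothetical save path

    // LDAModel.load reads the "class" field from the saved metadata and returns the
    // matching subtype, so callers need not know whether the model was trained with
    // the "em" or "online" optimizer.
    val reloaded: LDAModel = LDAModel.load("/tmp/lda-model")
    reloaded match {
      case m: DistributedLDAModel => println(s"EM-trained model, vocabSize=${m.vocabSize}")
      case m: LocalLDAModel => println(s"Online-trained model, vocabSize=${m.vocabSize}")
    }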