From c957730e60cc237ce684a94e0b4867ebadd938c7 Mon Sep 17 00:00:00 2001 From: zlpmichelle Date: Mon, 27 Jun 2016 23:00:30 -0700 Subject: [PATCH 1/2] [SPARK-16241] [ML] model loading backward compatibility for ml NaiveBayes #16241 --- .../apache/spark/ml/classification/NaiveBayes.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 7c340312df3e1..431b79b9b1b58 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -28,8 +28,9 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Row, Dataset} /** * Params for Naive Bayes Classifiers. @@ -275,9 +276,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head() - val pi = data.getAs[Vector](0) - val theta = data.getAs[Matrix](1) + val data = sparkSession.read.parquet(dataPath) + val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") + val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta") + .select("pi", "theta") + .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) DefaultParamsReader.getAndSetParams(model, metadata) From ce8392994cec6cafbf8e34d8eaac715481199bc7 Mon Sep 17 00:00:00 2001 From: zlpmichelle Date: Wed, 29 Jun 2016 02:30:29 -0700 Subject: [PATCH 2/2] fix Scala style issue --- .../scala/org/apache/spark/ml/classification/NaiveBayes.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 431b79b9b1b58..c99ae30155e3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesMo import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, Dataset} +import org.apache.spark.sql.{Dataset, Row} /** * Params for Naive Bayes Classifiers.