From 7ea603acd00846e3984105df75f85fda855c405a Mon Sep 17 00:00:00 2001 From: Christoph Sawade Date: Fri, 22 Aug 2014 22:38:40 +0200 Subject: [PATCH 1/5] Clarify learning interfaces * Make threshold mandatory Currently, the output of ``predict`` for an example is either the score or the class. This side-effect is caused by ``clearThreshold``. To clarify that behaviour, three different types of predict (predictScore, predictClass, predictProbability) were introduced; the threshold is no longer optional. * Clarify classification interfaces Currently, some functionality is spread over multiple models. In order to clarify the structure and simplify the implementation of more complex models (like multinomial logistic regression), two new classes are introduced: - BinaryClassificationModel: for all models that derive a binary classification from a single weight vector. It comprises the thresholding functionality to derive a prediction from a score. It basically captures SVMModel and LogisticRegressionModel. - ProbabilisticClassificationModel: This trait defines the interface for models that return a calibrated confidence score (aka probability). * Misc - some renaming - add test for probabilistic output --- .../examples/mllib/BinaryClassification.scala | 6 +- .../examples/mllib/LinearRegression.scala | 2 +- .../examples/mllib/SparseNaiveBayes.scala | 2 +- .../BinaryClassificationModel.scala | 62 +++++++++++++++++++ .../classification/ClassificationModel.scala | 8 +-- .../classification/LogisticRegression.scala | 39 +++--------- .../mllib/classification/NaiveBayes.scala | 6 +- .../ProbabilisticClassificationModel.scala | 46 ++++++++++++++ .../spark/mllib/classification/SVM.scala | 60 ++---------------- .../GeneralizedLinearAlgorithm.scala | 12 ++-- .../apache/spark/mllib/regression/Lasso.scala | 2 +- .../mllib/regression/LinearRegression.scala | 2 +- .../mllib/regression/RegressionModel.scala | 8 +-- .../mllib/regression/RidgeRegression.scala | 2 +- .../regression/StreamingLinearAlgorithm.scala | 4 +- .../JavaLogisticRegressionSuite.java | 2 +- .../mllib/regression/JavaLassoSuite.java | 2 +- .../regression/JavaLinearRegressionSuite.java | 4 +- .../regression/JavaRidgeRegressionSuite.java | 2 +- .../LogisticRegressionSuite.scala | 42 +++++++++---- .../classification/NaiveBayesSuite.scala | 6 +- .../spark/mllib/classification/SVMSuite.scala | 16 ++--- .../spark/mllib/regression/LassoSuite.scala | 10 +-- .../regression/LinearRegressionSuite.scala | 14 ++--- .../regression/RidgeRegressionSuite.scala | 6 +- .../StreamingLinearRegressionSuite.scala | 2 +- 26 files changed, 213 insertions(+), 154 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/ProbabilisticClassificationModel.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a6f78d2441db..3cd13e795097 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -131,7 +131,7 @@ object BinaryClassification { .setNumIterations(params.numIterations) .setUpdater(updater) .setRegParam(params.regParam) - algorithm.run(training).clearThreshold() + algorithm.run(training) case SVM => val algorithm = new SVMWithSGD() algorithm.optimizer @@ -139,10
+139,10 @@ object BinaryClassification { .setStepSize(params.stepSize) .setUpdater(updater) .setRegParam(params.regParam) - algorithm.run(training).clearThreshold() + algorithm.run(training) } - val prediction = model.predict(test.map(_.features)) + val prediction = model.predictClass(test.map(_.features)) val predictionAndLabel = prediction.zip(test.map(_.label)) val metrics = new BinaryClassificationMetrics(predictionAndLabel) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 05b7d66f8dff..c868976243aa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -118,7 +118,7 @@ object LinearRegression extends App { val model = algorithm.run(training) - val prediction = model.predict(test.map(_.features)) + val prediction = model.predictScore(test.map(_.features)) val predictionAndLabel = prediction.zip(test.map(_.label)) val loss = predictionAndLabel.map { case (p, l) => diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index 952fa2a5109a..532045221d0a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -91,7 +91,7 @@ object SparseNaiveBayes { val model = new NaiveBayes().setLambda(params.lambda).run(training) - val prediction = model.predict(test.map(_.features)) + val prediction = model.predictClass(test.map(_.features)) val predictionAndLabel = prediction.zip(test.map(_.label)) val accuracy = predictionAndLabel.filter(x => x._1 == x._2).count().toDouble / numTest diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala new file mode 100644 index 000000000000..a331908797ea --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.regression.GeneralizedLinearModel + +/** + * :: Experimental :: + * Represents a classification model that predicts to which of a set of categories an example + * belongs. 
The categories are represented by double values: 0.0, 1.0. + */ +@Experimental +class BinaryClassificationModel ( + override val weights: Vector, + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel[Double] with Serializable { + + protected var threshold: Double = 0.0 + + /** + * :: Experimental :: + * Setter and getter for the threshold. The threshold separates positive predictions from + * negative predictions. An example with prediction score greater than or equal to this + * threshold is identified as positive, and negative otherwise. The default value is 0.0. + */ + @Experimental + def setThreshold(threshold: Double): this.type = { + this.threshold = threshold + this + } + + def getThreshold = threshold + + private def compareWithThreshold(value: Double): Double = + if (value < threshold) 0.0 else 1.0 + + def predictClass(testData: RDD[Vector]): RDD[Double] = { + predictScore(testData).map(compareWithThreshold) + } + + def predictClass(testData: Vector): Double = { + compareWithThreshold(predictScore(testData)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index b7a1d90d24d7..f88c1299f81e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -30,15 +30,15 @@ import org.apache.spark.rdd.RDD @Experimental trait ClassificationModel extends Serializable { /** - * Predict values for the given data set using the model trained. + * Classify the given data set using the model trained. * - * @param testData RDD representing data points to be predicted + * @param testData RDD representing data points to be classified * @return an RDD[Double] where each entry contains the corresponding prediction */ - def predict(testData: RDD[Vector]): RDD[Double] + def predictClass(testData: RDD[Vector]): RDD[Double] /** - * Predict values for a single data point using the model trained. + * Classify a single data point using the model trained. * * @param testData array representing a single data point * @return predicted category from the trained model diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 486bdbfa9cb4..03e2c3b5a02c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ @@ -33,45 +32,23 @@ class LogisticRegressionModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + extends BinaryClassificationModel(weights, intercept) with ProbabilisticClassificationModel { - private var threshold: Option[Double] = Some(0.5) - - /** - * :: Experimental :: - * Sets the threshold that separates positive predictions from negative predictions.
An example - * with prediction score greater than or equal to this threshold is identified as an positive, - * and negative otherwise. The default value is 0.5. - */ - @Experimental - def setThreshold(threshold: Double): this.type = { - this.threshold = Some(threshold) - this + protected def computeProbability(value: Double) = { + 1.0 / (1.0 + math.exp(-value)) } - /** - * :: Experimental :: - * Clears the threshold so that `predict` will output raw prediction scores. - */ - @Experimental - def clearThreshold(): this.type = { - threshold = None - this + def predictProbability(testData: RDD[Vector]): RDD[Double] = { + predictScore(testData).map(computeProbability) } - override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, - intercept: Double) = { - val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - val score = 1.0 / (1.0 + math.exp(-margin)) - threshold match { - case Some(t) => if (score < t) 0.0 else 1.0 - case None => score - } + def predictProbability(testData: Vector): Double = { + computeProbability(predictScore(testData)) } } /** * Train a classification model for Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} * * Using [[LogisticRegressionWithLBFGS]] is recommended over this. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 8c8e4a161aa5..dfdcab353884 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -54,15 +54,15 @@ class NaiveBayesModel private[mllib] ( } } - override def predict(testData: RDD[Vector]): RDD[Double] = { + override def predictClass(testData: RDD[Vector]): RDD[Double] = { val bcModel = testData.context.broadcast(this) testData.mapPartitions { iter => val model = bcModel.value - iter.map(model.predict) + iter.map(model.predictClass) } } - override def predict(testData: Vector): Double = { + override def predictClass(testData: Vector): Double = { labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ProbabilisticClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ProbabilisticClassificationModel.scala new file mode 100644 index 000000000000..4edb5b59abf2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ProbabilisticClassificationModel.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.mllib.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Represents a probabilistic classification model that provides a probability + * distribution over a set of classes, rather than only predicting a class. + */ +@Experimental +trait ProbabilisticClassificationModel extends ClassificationModel { + /** + * Return probability for the prediction of the given data set using the model trained. + * + * @param testData RDD representing data points to be classified + * @return an RDD[Double] where each entry contains the corresponding probability + */ + def predictProbability(testData: RDD[Vector]): RDD[Double] + + /** + * Return probability for a single data point prediction using the model trained. + * + * @param testData array representing a single data point + * @return predicted probability from the trained model + */ + def predictProbability(testData: Vector): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 80f8a1b2f1e8..c3b73e783e86 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -17,60 +17,12 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.DataValidators import org.apache.spark.rdd.RDD -/** - * Model for Support Vector Machines (SVMs). - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -class SVMModel ( - override val weights: Vector, - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { - - private var threshold: Option[Double] = Some(0.0) - - /** - * :: Experimental :: - * Sets the threshold that separates positive predictions from negative predictions. An example - * with prediction score greater than or equal to this threshold is identified as an positive, - * and negative otherwise. The default value is 0.0. - */ - @Experimental - def setThreshold(threshold: Double): this.type = { - this.threshold = Some(threshold) - this - } - - /** - * :: Experimental :: - * Clears the threshold so that `predict` will output raw prediction scores. - */ - @Experimental - def clearThreshold(): this.type = { - threshold = None - this - } - - override protected def predictPoint( - dataMatrix: Vector, - weightMatrix: Vector, - intercept: Double) = { - val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - threshold match { - case Some(t) => if (margin < t) 0.0 else 1.0 - case None => margin - } - } -} - /** * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. * NOTE: Labels used in SVM should be {0, 1}.
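// A minimal usage sketch of the clarified interfaces introduced above; it is not
// part of the patch. `training` and `test` are hypothetical RDD[LabeledPoint]
// values prepared by the caller. predictScore returns the raw margins that
// clearThreshold() + predict used to give, predictClass returns thresholded
// 0.0/1.0 labels, and predictProbability is only available on probabilistic
// models such as LogisticRegressionModel.
import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

object PredictionFlavorsSketch {
  def run(training: RDD[LabeledPoint], test: RDD[LabeledPoint]): Unit = {
    val svm = SVMWithSGD.train(training, 100)
    // Raw scores, previously obtained via clearThreshold() + predict:
    val margins = svm.predictScore(test.map(_.features))
    // Hard 0.0/1.0 labels, with the threshold now set explicitly:
    val labels = svm.setThreshold(0.0).predictClass(test.map(_.features))

    val lr = LogisticRegressionWithSGD.train(training, 100)
    // Calibrated confidences in [0, 1]:
    val probabilities = lr.predictProbability(test.map(_.features))
  }
}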
@@ -80,7 +32,7 @@ class SVMWithSGD private ( private var numIterations: Int, private var regParam: Double, private var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[SVMModel] with Serializable { + extends GeneralizedLinearAlgorithm[BinaryClassificationModel] with Serializable { private val gradient = new HingeGradient() private val updater = new SquaredL2Updater() @@ -97,7 +49,7 @@ class SVMWithSGD private ( def this() = this(1.0, 100, 1.0, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { - new SVMModel(weights, intercept) + new BinaryClassificationModel(weights, intercept) } } @@ -128,7 +80,7 @@ object SVMWithSGD { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): SVMModel = { + initialWeights: Vector): BinaryClassificationModel = { new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction) .run(input, initialWeights) } @@ -150,7 +102,7 @@ object SVMWithSGD { numIterations: Int, stepSize: Double, regParam: Double, - miniBatchFraction: Double): SVMModel = { + miniBatchFraction: Double): BinaryClassificationModel = { new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) } @@ -170,7 +122,7 @@ object SVMWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - regParam: Double): SVMModel = { + regParam: Double): BinaryClassificationModel = { train(input, numIterations, stepSize, regParam, 1.0) } @@ -184,7 +136,7 @@ object SVMWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. */ - def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { + def train(input: RDD[LabeledPoint], numIterations: Int): BinaryClassificationModel = { train(input, numIterations, 1.0, 1.0, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 20c1fdd2269c..6b1872f1110b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -45,7 +45,9 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * @param weightMatrix Column vector containing the weights of the model * @param intercept Intercept of the model. */ - protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double + protected def computeScore(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { + weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept + } /** * Predict values for the given data set using the model trained. @@ -53,7 +55,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction */ - def predict(testData: RDD[Vector]): RDD[Double] = { + def predictScore(testData: RDD[Vector]): RDD[Double] = { // A small optimization to avoid serializing the entire model. Only the weightsMatrix // and intercept is needed. 
val localWeights = weights @@ -61,7 +63,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double val localIntercept = intercept testData.mapPartitions { iter => val w = bcWeights.value - iter.map(v => predictPoint(v, w, localIntercept)) + iter.map(v => computeScore(v, w, localIntercept)) } } @@ -71,8 +73,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * @param testData array representing a single data point * @return Double prediction from the trained model */ - def predict(testData: Vector): Double = { - predictPoint(testData, weights, intercept) + def predictScore(testData: Vector): Double = { + computeScore(testData, weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index cb0d39e759a9..912de6651aa7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -34,7 +34,7 @@ class LassoModel ( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override protected def predictPoint( + override protected def computeScore( dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 81b6598377ff..6702675d7d21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -32,7 +32,7 @@ class LinearRegressionModel ( override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override protected def predictPoint( + override protected def computeScore( dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 64b02f7a6e7a..9a4c80342d6e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -30,7 +30,7 @@ trait RegressionModel extends Serializable { * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction */ - def predict(testData: RDD[Vector]): RDD[Double] + def predictScore(testData: RDD[Vector]): RDD[Double] /** * Predict values for a single data point using the model trained. @@ -38,13 +38,13 @@ trait RegressionModel extends Serializable { * @param testData array representing a single data point * @return Double prediction from the trained model */ - def predict(testData: Vector): Double + def predictScore(testData: Vector): Double /** * Predict values for examples stored in a JavaRDD. 
* @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction */ - def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = - predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] + def predictScore(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = + predictScore(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index a826deb695ee..4d5d00f5349c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -34,7 +34,7 @@ class RidgeRegressionModel ( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override protected def predictPoint( + override protected def computeScore( dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index 8db0442a7a56..3f3eb6620ebf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -105,7 +105,7 @@ abstract class StreamingLinearAlgorithm[ logError(msg) throw new IllegalArgumentException(msg) } - data.map(model.predict) + data.map(model.predictScore) } /** @@ -120,6 +120,6 @@ abstract class StreamingLinearAlgorithm[ logError(msg) throw new IllegalArgumentException(msg) } - data.mapValues(model.predict) + data.mapValues(model.predictScore) } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java index 862221d48798..49c743cbd180 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -47,7 +47,7 @@ public void tearDown() { int validatePrediction(List validationData, LogisticRegressionModel model) { int numAccurate = 0; for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); + Double prediction = model.predictScore(point.features()); if (prediction == point.label()) { numAccurate++; } diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java index 8950b48888b7..d144f9e93437 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java @@ -46,7 +46,7 @@ public void tearDown() { int validatePrediction(List validationData, LassoModel model) { int numAccurate = 0; for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); + Double prediction = model.predictScore(point.features()); // A prediction is off if the prediction is more than 0.5 away from expected value. 
if (Math.abs(prediction - point.label()) <= 0.5) { numAccurate++; diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java index 24c4c20d9af1..a53e5b4dea1a 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -48,7 +48,7 @@ public void tearDown() { int validatePrediction(List validationData, LinearRegressionModel model) { int numAccurate = 0; for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); + Double prediction = model.predictScore(point.features()); // A prediction is off if the prediction is more than 0.5 away from expected value. if (Math.abs(prediction - point.label()) <= 0.5) { numAccurate++; @@ -108,7 +108,7 @@ public Vector call(LabeledPoint v) throws Exception { return v.features(); } }); - JavaRDD predictions = model.predict(vectors); + JavaRDD predictions = model.predictScore(vectors); // Should be able to get the first prediction. predictions.first(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java index 7266eec23580..76da3b2e240e 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -48,7 +48,7 @@ public void tearDown() { double predictionError(List validationData, RidgeRegressionModel model) { double errorSum = 0; for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); + Double prediction = model.predictScore(point.features()); errorSum += (prediction - point.label()) * (prediction - point.label()); } return errorSum / validationData.size(); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 862178694a50..f9d357d4e21a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConversions._ import org.scalatest.FunSuite import org.scalatest.Matchers -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -66,6 +66,26 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83 } + test("logistic output") { + val EPS = 1E-8 + + def prob2score(prob: Double) = -math.log(1.0 / prob - 1) + + val expectedProbs = Seq(0.0, 0.05, 0.2, 0.5, 0.8, 1.0) + + // The first feature cancels the intercept, so the second feature is the raw score. + val model = new LogisticRegressionModel(Vectors.dense(Array(-1.0, 1.0)), 1.0) + val testData = expectedProbs.map { prob => + val score = prob2score(prob) + Vectors.dense(Array(1.0, score)) + }.toSeq + + val probs = model.predictProbability(sc.parallelize(testData)).collect().toSeq + + probs.zip(expectedProbs).foreach { case (actual:
Double, expected: Double) => + Math.abs(actual - expected) should be < EPS + } + } + // Test if we can correctly learn A, B where Y = logistic(A + B*X) test("logistic regression with SGD") { val nPoints = 10000 @@ -88,10 +108,10 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictClass(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictClass(row.features)), validationData) } // Test if we can correctly learn A, B where Y = logistic(A + B*X) @@ -117,10 +137,10 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictClass(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictClass(row.features)), validationData) } test("logistic regression with initial weights with SGD") { @@ -149,10 +169,10 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictClass(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictClass(row.features)), validationData) } test("logistic regression with initial weights with LBFGS") { @@ -180,10 +200,10 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictClass(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictClass(row.features)), validationData) } test("numerical stability of scaling features using logistic regression with LBFGS") { @@ -257,7 +277,7 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont // greater than 1MB and hence Spark would throw an error. 
val model = LogisticRegressionWithSGD.train(points, 2) - val predictions = model.predict(points.map(_.features)) + val predictions = model.predictScore(points.map(_.features)) // Materialize the RDDs predictions.count() @@ -276,7 +296,7 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont lr.optimizer.setNumIterations(2) val model = lr.run(points) - val predictions = model.predict(points.map(_.features)) + val predictions = model.predictScore(points.map(_.features)) // Materialize the RDDs predictions.count() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 80989bc074e8..49f3dc1e1e5a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -91,10 +91,10 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext { val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictClass(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictClass(row.features)), validationData) } test("detect negative values") { @@ -139,6 +139,6 @@ class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. val model = NaiveBayes.train(examples) - val predictions = model.predict(examples.map(_.features)) + val predictions = model.predictClass(examples.map(_.features)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 65e5df58db4c..f416b447e752 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -91,17 +91,17 @@ class SVMSuite extends FunSuite with LocalSparkContext { // Test prediction on RDD. - var predictions = model.predict(validationRDD.map(_.features)).collect() + var predictions = model.predictClass(validationRDD.map(_.features)).collect() assert(predictions.count(_ == 0.0) != predictions.length) // High threshold makes all the predictions 0.0 model.setThreshold(10000.0) - predictions = model.predict(validationRDD.map(_.features)).collect() + predictions = model.predictClass(validationRDD.map(_.features)).collect() assert(predictions.count(_ == 0.0) == predictions.length) // Low threshold makes all the predictions 1.0 model.setThreshold(-10000.0) - predictions = model.predict(validationRDD.map(_.features)).collect() + predictions = model.predictClass(validationRDD.map(_.features)).collect() assert(predictions.count(_ == 1.0) == predictions.length) } @@ -127,10 +127,10 @@ class SVMSuite extends FunSuite with LocalSparkContext { val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. 
- validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictClass(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictClass(row.features)), validationData) } test("SVM local random SGD with initial weights") { @@ -159,10 +159,10 @@ class SVMSuite extends FunSuite with LocalSparkContext { val validationRDD = sc.parallelize(validationData,2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictClass(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictClass(row.features)), validationData) } test("SVM with invalid labels") { @@ -205,6 +205,6 @@ class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. val model = SVMWithSGD.train(points, 2) - val predictions = model.predict(points.map(_.features)) + val predictions = model.predictClass(points.map(_.features)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 7aa96421aed8..05b297cb7dd4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -67,10 +67,10 @@ class LassoSuite extends FunSuite with LocalSparkContext { val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictScore(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictScore(row.features)), validationData) } test("Lasso local random SGD with initial weights") { @@ -110,10 +110,10 @@ class LassoSuite extends FunSuite with LocalSparkContext { val validationRDD = sc.parallelize(validationData,2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictScore(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictScore(row.features)), validationData) } } @@ -129,6 +129,6 @@ class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. 
val model = LassoWithSGD.train(points, 2) - val predictions = model.predict(points.map(_.features)) + val predictions = model.predictScore(points.map(_.features)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 4f89112b650c..3d776484268f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -56,10 +56,10 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { val validationRDD = sc.parallelize(validationData, 2).cache() // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictScore(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictScore(row.features)), validationData) } // Test if we can correctly learn Y = 10*X1 + 10*X2 @@ -83,10 +83,10 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { val validationRDD = sc.parallelize(validationData, 2).cache() // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction(model.predictScore(validationRDD.map(_.features)).collect(), validationData) // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + validatePrediction(validationData.map(row => model.predictScore(row.features)), validationData) } // Test if we can correctly learn Y = 10*X1 + 10*X10000 @@ -118,11 +118,11 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { // Test prediction on RDD. validatePrediction( - model.predict(sparseValidationRDD.map(_.features)).collect(), sparseValidationData) + model.predictScore(sparseValidationRDD.map(_.features)).collect(), sparseValidationData) // Test prediction on Array. validatePrediction( - sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) + sparseValidationData.map(row => model.predictScore(row.features)), sparseValidationData) } } @@ -138,6 +138,6 @@ class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContex // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. 
val model = LinearRegressionWithSGD.train(points, 2) - val predictions = model.predict(points.map(_.features)) + val predictions = model.predictScore(points.map(_.features)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 727bbd051ff1..25c318a4a6df 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -61,7 +61,7 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext { val linearModel = linearReg.run(testRDD) val linearErr = predictionError( - linearModel.predict(validationRDD.map(_.features)).collect(), validationData) + linearModel.predictScore(validationRDD.map(_.features)).collect(), validationData) val ridgeReg = new RidgeRegressionWithSGD() ridgeReg.optimizer.setNumIterations(200) @@ -69,7 +69,7 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext { .setStepSize(1.0) val ridgeModel = ridgeReg.run(testRDD) val ridgeErr = predictionError( - ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData) + ridgeModel.predictScore(validationRDD.map(_.features)).collect(), validationData) // Ridge validation error should be lower than linear regression. assert(ridgeErr < linearErr, @@ -89,6 +89,6 @@ class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. val model = RidgeRegressionWithSGD.train(points, 2) - val predictions = model.predict(points.map(_.features)) + val predictions = model.predictScore(points.map(_.features)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 03b71301e9ab..22d039a361d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -75,7 +75,7 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { // check accuracy of predictions val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17) - validatePrediction(validationData.map(row => model.latestModel().predict(row.features)), + validatePrediction(validationData.map(row => model.latestModel().predictScore(row.features)), validationData) } From 6054b9ce0d2f436b094505dda21fd564058b1c9a Mon Sep 17 00:00:00 2001 From: Christoph Sawade Date: Wed, 27 Aug 2014 15:08:48 +0200 Subject: [PATCH 2/5] Fix test issues - scalastyle issues - java test suite - java logistic regression suite - add deprecated versions --- .../BinaryClassificationModel.scala | 58 +++++++++++++++++-- .../classification/ClassificationModel.scala | 6 +- .../classification/LogisticRegression.scala | 12 +++- .../mllib/classification/NaiveBayes.scala | 5 ++ .../spark/mllib/classification/SVM.scala | 33 +++++++++-- .../GeneralizedLinearAlgorithm.scala | 46 ++++++++++++++- .../mllib/regression/LinearRegression.scala | 12 ++-- .../mllib/regression/RidgeRegression.scala | 14 ++--- .../JavaLogisticRegressionSuite.java | 2 +- .../classification/JavaNaiveBayesSuite.java | 4 +- .../mllib/classification/JavaSVMSuite.java 
| 8 +-- 11 files changed, 162 insertions(+), 38 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala index a331908797ea..503b913dfcbe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala @@ -17,32 +17,32 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.GeneralizedLinearModel +import org.apache.spark.api.java.JavaRDD +import scala.deprecated /** - * :: Experimental :: * Represents a classification model that predicts to which of a set of categories an example * belongs. The categories are represented by double values: 0.0, 1.0. */ -@Experimental class BinaryClassificationModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel[Double] with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { protected var threshold: Double = 0.0 + @deprecated + protected var useThreshold: Boolean = true /** - * :: Experimental :: * Setter and getter for the threshold. The threshold separates positive predictions from * negative predictions. An example with prediction score greater than or equal to this * threshold is identified as positive, and negative otherwise. The default value is 0.0. */ - @Experimental def setThreshold(threshold: Double): this.type = { + this.useThreshold = true this.threshold = threshold this } @@ -59,4 +59,50 @@ class BinaryClassificationModel ( def predictClass(testData: Vector): Double = { compareWithThreshold(predictScore(testData)) } + + /** + * :: Deprecated :: + * Clears the threshold so that `predict` will output raw prediction scores. + */ + @deprecated + def clearThreshold(): this.type = { + this.useThreshold = false + this + } + + /** + * :: Deprecated :: + * Predict values for the given data set using the model trained. + * + * @param testData RDD representing data points to be predicted + * @return an RDD[Double] where each entry contains the corresponding prediction + */ + @deprecated + override def predict(testData: RDD[Vector]): RDD[Double] = { + if (useThreshold) predictClass(testData) + else predictScore(testData) + } + + /** + * :: Deprecated :: + * Predict values for a single data point using the model trained. + * + * @param testData array representing a single data point + * @return predicted category from the trained model + */ + @deprecated + override def predict(testData: Vector): Double = { + if (useThreshold) predictClass(testData) + else predictScore(testData) + } + + /** + * :: Deprecated :: + * Predict values for examples stored in a JavaRDD.
+ * @param testData JavaRDD representing data points to be predicted + * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction + */ + @deprecated + def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = + predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index f88c1299f81e..b6c1789efa57 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -43,13 +43,13 @@ trait ClassificationModel extends Serializable { * @param testData array representing a single data point * @return predicted category from the trained model */ - def predict(testData: Vector): Double + def predictClass(testData: Vector): Double /** * Predict values for examples stored in a JavaRDD. * @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction */ - def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = - predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] + def predictClass(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = + predictClass(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 03e2c3b5a02c..d00f0253b356 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -45,13 +45,21 @@ class LogisticRegressionModel ( def predictProbability(testData: Vector): Double = { computeProbability(predictScore(testData)) } + + @deprecated + override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, + intercept: Double) = { + if (useThreshold) predictClass(dataMatrix) + else predictProbability(dataMatrix) + } } /** * Train a classification model for Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} * - * Using [[LogisticRegressionWithLBFGS]] is recommended over this. + * Using [[org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS]] + * is recommended over this.
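// A sketch (not part of the patch) of the invariants the reworked
// LogisticRegressionModel keeps between the three predict variants, assuming a
// hypothetical model instance `model` and feature vector `v`: predictClass
// thresholds the raw score via compareWithThreshold, and predictProbability
// pushes the same score through the logistic function computeProbability.
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.linalg.Vector

object ThresholdInvariantsSketch {
  def check(model: LogisticRegressionModel, v: Vector): Unit = {
    val score = model.predictScore(v)
    // predictClass applies the now-mandatory threshold to the raw score.
    val expectedClass = if (score < model.getThreshold) 0.0 else 1.0
    assert(model.predictClass(v) == expectedClass)
    // predictProbability is the calibrated sigmoid of the same score.
    val expectedProb = 1.0 / (1.0 + math.exp(-score))
    assert(math.abs(model.predictProbability(v) - expectedProb) < 1e-12)
  }
}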
*/ class LogisticRegressionWithSGD private ( private var stepSize: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index dfdcab353884..a497bb53d5b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -65,6 +65,11 @@ class NaiveBayesModel private[mllib] ( override def predictClass(testData: Vector): Double = { labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) } + + @deprecated + def predict(testData: RDD[Vector]): RDD[Double] = predictClass(testData) + @deprecated + def predict(testData: Vector): Double = predictClass(testData) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index c3b73e783e86..c3e60f1e9b2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -23,6 +23,27 @@ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.DataValidators import org.apache.spark.rdd.RDD +/** + * Model for Support Vector Machines (SVMs). + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +class SVMModel ( + override val weights: Vector, + override val intercept: Double) + extends BinaryClassificationModel(weights, intercept) { + + @deprecated + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, + intercept: Double) = { + if (useThreshold) predictClass(dataMatrix) + else predictScore(dataMatrix) + } +} + /** * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. * NOTE: Labels used in SVM should be {0, 1}. @@ -32,7 +53,7 @@ class SVMWithSGD private ( private var numIterations: Int, private var regParam: Double, private var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[BinaryClassificationModel] with Serializable { + extends GeneralizedLinearAlgorithm[SVMModel] with Serializable { private val gradient = new HingeGradient() private val updater = new SquaredL2Updater() @@ -49,7 +70,7 @@ class SVMWithSGD private ( def this() = this(1.0, 100, 1.0, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { - new BinaryClassificationModel(weights, intercept) + new SVMModel(weights, intercept) } } @@ -80,7 +101,7 @@ object SVMWithSGD { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): BinaryClassificationModel = { + initialWeights: Vector): SVMModel = { new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction) .run(input, initialWeights) } @@ -102,7 +123,7 @@ object SVMWithSGD { numIterations: Int, stepSize: Double, regParam: Double, - miniBatchFraction: Double): BinaryClassificationModel = { + miniBatchFraction: Double): SVMModel = { new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) } @@ -122,7 +143,7 @@ object SVMWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - regParam: Double): BinaryClassificationModel = { + regParam: Double): SVMModel = { train(input, numIterations, stepSize, regParam, 1.0) } @@ -136,7 +157,7 @@ object SVMWithSGD { * @param numIterations Number of iterations of gradient descent to run. 
* @return a SVMModel which has the weights and offset from training. */ - def train(input: RDD[LabeledPoint], numIterations: Int): BinaryClassificationModel = { + def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { train(input, numIterations, 1.0, 1.0, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 6b1872f1110b..a9685c815554 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -45,7 +45,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * @param weightMatrix Column vector containing the weights of the model * @param intercept Intercept of the model. */ - protected def computeScore(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { + protected def computeScore( + dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } @@ -76,6 +77,49 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double def predictScore(testData: Vector): Double = { computeScore(testData, weights, intercept) } + + /** + * :: Deprecated :: + * Predict the result given a data point and the weights learned. + * + * @param dataMatrix Row vector containing the features for this data point + * @param weightMatrix Column vector containing the weights of the model + * @param intercept Intercept of the model. + */ + @deprecated + protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double + + /** + * :: Deprecated :: + * Predict values for the given data set using the model trained. + * + * @param testData RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + @deprecated + def predict(testData: RDD[Vector]): RDD[Double] = { + // A small optimization to avoid serializing the entire model. Only the weightsMatrix + // and intercept is needed. + val localWeights = weights + val bcWeights = testData.context.broadcast(localWeights) + val localIntercept = intercept + testData.mapPartitions { iter => + val w = bcWeights.value + iter.map(v => predictPoint(v, w, localIntercept)) + } + } + + /** + * :: Deprecated :: + * Predict values for a single data point using the model trained.
+ * + * @param testData array representing a single data point + * @return Double prediction from the trained model + */ + @deprecated + def predict(testData: Vector): Double = { + predictPoint(testData, weights, intercept) + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 6702675d7d21..3867639dc919 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -32,12 +32,12 @@ class LinearRegressionModel ( override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override protected def computeScore( - dataMatrix: Vector, - weightMatrix: Vector, - intercept: Double): Double = { - weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - } + @deprecated + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, + intercept: Double) = + predictScore(dataMatrix) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 4d5d00f5349c..7a3bd0c74ee4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.Vector @@ -34,12 +33,13 @@ class RidgeRegressionModel ( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override protected def computeScore( - dataMatrix: Vector, - weightMatrix: Vector, - intercept: Double): Double = { - weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - } + @deprecated + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, + intercept: Double) = + predictScore(dataMatrix) + } /** diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java index 49c743cbd180..d0c59090e79b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -47,7 +47,7 @@ public void tearDown() { int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) { int numAccurate = 0; for (LabeledPoint point: validationData) { - Double prediction = model.predictScore(point.features()); + Double prediction = model.predictClass(point.features()); if (prediction == point.label()) { numAccurate++; } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 1c90522a0714..be0a5eb8cd46 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -58,7 +58,7 @@ public void tearDown() { private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) { int correct = 0; for (LabeledPoint p: points)
{ - if (model.predict(p.features()) == p.label()) { + if (model.predictClass(p.features()) == p.label()) { correct += 1; } } @@ -98,7 +98,7 @@ public void testPredictJavaRDD() { public Vector call(LabeledPoint v) throws Exception { return v.features(); }}); - JavaRDD<Double> predictions = model.predict(vectors); + JavaRDD<Double> predictions = model.predictClass(vectors); // Should be able to get the first prediction. predictions.first(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java index 31b9f3e8d438..cf9ba6d83996 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -43,10 +43,10 @@ public void tearDown() { sc = null; } - int validatePrediction(List<LabeledPoint> validationData, SVMModel model) { + int validatePrediction(List<LabeledPoint> validationData, BinaryClassificationModel model) { int numAccurate = 0; for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); + Double prediction = model.predictClass(point.features()); if (prediction == point.label()) { numAccurate++; } @@ -70,7 +70,7 @@ public void runSVMUsingConstructor() { svmSGDImpl.optimizer().setStepSize(1.0) .setRegParam(1.0) .setNumIterations(100); - SVMModel model = svmSGDImpl.run(testRDD.rdd()); + BinaryClassificationModel model = svmSGDImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); @@ -87,7 +87,7 @@ public void runSVMUsingStaticMethods() { List<LabeledPoint> validationData = SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); - SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0); + BinaryClassificationModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0); int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); From 0c4341e57318d1328698b1e0639ae537d927ac65 Mon Sep 17 00:00:00 2001 From: Christoph Sawade Date: Mon, 1 Sep 2014 15:15:42 +0200 Subject: [PATCH 3/5] WIP --- .../apache/spark/mllib/classification/LogisticRegression.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index d00f0253b356..3cb657f1aed8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -59,7 +59,7 @@ class LogisticRegressionModel ( * Broyden–Fletcher–Goldfarb–Shanno algorithm. * NOTE: Labels used in Logistic Regression should be {0, 1} * - * Using [[org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS]] is recommended over this. + * Using [[LogisticRegressionWithLBFGS]] is recommended over this.
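+ * (Editor's illustrative addition, assuming the standard MLlib training entry
+ * points; `training` is a hypothetical RDD[LabeledPoint]:)
+ * {{{
+ *   val model = new LogisticRegressionWithLBFGS().run(training)
+ * }}}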
*/ class LogisticRegressionWithSGD private ( private var stepSize: Double, From 140f09c694ceb64cb246e1aed5b1fc7035d91a18 Mon Sep 17 00:00:00 2001 From: Christoph Sawade Date: Mon, 1 Sep 2014 16:53:12 +0200 Subject: [PATCH 4/5] WIP --- .../BinaryClassificationModel.scala | 32 +++++++++++++------ .../classification/LogisticRegression.scala | 5 ++- .../mllib/classification/NaiveBayes.scala | 9 +++++- .../spark/mllib/classification/SVM.scala | 9 ------ .../GeneralizedLinearAlgorithm.scala | 6 ++-- .../apache/spark/mllib/regression/Lasso.scala | 16 ++++++---- .../mllib/regression/LinearRegression.scala | 6 +++- .../mllib/regression/RidgeRegression.scala | 7 ++-- 8 files changed, 57 insertions(+), 33 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala index 503b913dfcbe..a382ce8c24c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala @@ -21,7 +21,6 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.GeneralizedLinearModel import org.apache.spark.api.java.JavaRDD -import scala.deprecated /** * Represents a classification model that predicts to which of a set of categories an example * belongs. @@ -33,7 +32,8 @@ class BinaryClassificationModel ( extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { protected var threshold: Double = 0.0 - @deprecated + + // This is only used to preserve the prior behaviour of the deprecated `predict` protected var useThreshold: Boolean = true /** @@ -61,7 +61,7 @@ class BinaryClassificationModel ( } /** - * :: Deprecated :: + * DEPRECATED: Use predictScore(...) or predictClass(...) instead * Clears the threshold so that `predict` will output raw prediction scores. */ @Deprecated @@ -71,38 +71,50 @@ class BinaryClassificationModel ( } /** - * :: Deprecated :: + * DEPRECATED: Use predictScore(...) or predictClass(...) instead + */ + @Deprecated + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, + intercept: Double) = { + if (useThreshold) predictClass(dataMatrix) + else predictScore(dataMatrix) + } + + /** + * DEPRECATED: Use predictScore(...) or predictClass(...) instead * Predict values for the given data set using the model trained. * * @param testData RDD representing data points to be predicted * @return an RDD[Double] where each entry contains the corresponding prediction */ - @deprecated + @Deprecated override def predict(testData: RDD[Vector]): RDD[Double] = { if (useThreshold) predictClass(testData) else predictScore(testData) } /** - * :: Deprecated :: + * DEPRECATED: Use predictScore(...) or predictClass(...) instead * Predict values for a single data point using the model trained. * * @param testData array representing a single data point * @return predicted category from the trained model */ - @deprecated - def predict(testData: Vector): Double = { + @Deprecated + override def predict(testData: Vector): Double = { if (useThreshold) predictClass(testData) else predictScore(testData) } /** - * :: Deprecated :: + * DEPRECATED: Use predictScore(...) or predictClass(...) instead * Predict values for examples stored in a JavaRDD.
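 * (Editor's note: under the deprecated API the same call could return either a
 * class label or a raw score depending on mutable threshold state, which is the
 * side-effect the predictScore/predictClass split removes. Illustrative only:)
 * {{{
 *   model.predict(x)        // class label, threshold applied
 *   model.clearThreshold()
 *   model.predict(x)        // raw score: same call, different meaning
 * }}}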
* @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction */ - @deprecated + @Deprecated def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 3cb657f1aed8..4125a310b008 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -46,7 +46,10 @@ class LogisticRegressionModel ( computeProbability(predictScore(testData)) } - @deprecated + /** + * DEPRECATED: Use predictProbability(...) or predictClass(...) instead + */ + @Deprecated override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { if (useThreshold) predictClass(dataMatrix) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a497bb53d5b1..6b50c3892aeb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -66,8 +66,15 @@ class NaiveBayesModel private[mllib] ( labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) } - @deprecated + /** + * DEPRECATED: Use predictClass(...) instead + */ + @Deprecated def predict(testData: RDD[Vector]): RDD[Double] = predictClass(testData) + + /** + * DEPRECATED: Use predictClass(...) instead + */ @deprecated def predict(testData: Vector): Double = predictClass(testData) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index c3e60f1e9b2c..7a0a4b11bb2e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -33,15 +33,6 @@ class SVMModel ( override val weights: Vector, override val intercept: Double) extends BinaryClassificationModel(weights, intercept) { - - @deprecated - override protected def predictPoint( - dataMatrix: Vector, - weightMatrix: Vector, - intercept: Double) = { - if (useThreshold) predictClass(dataMatrix) - else predictScore(dataMatrix) - } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index a9685c815554..d04d63c22b1f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -86,7 +86,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * @param weightMatrix Column vector containing the weights of the model * @param intercept Intercept of the model. 
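 * (Editor's sketch: concrete models implement this by delegating to predictScore
 * or predictClass; for logistic regression the calibrated output above is
 * presumably the logistic link applied to the same raw score. A minimal
 * standalone version, reusing the computeProbability name from the code above:)
 * {{{
 *   def computeProbability(score: Double): Double = 1.0 / (1.0 + math.exp(-score))
 * }}}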
*/ - @deprecated + @Deprecated protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double /** @@ -96,7 +96,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction */ - @deprecated + @Deprecated def predict(testData: RDD[Vector]): RDD[Double] = { // A small optimization to avoid serializing the entire model. Only the weightsMatrix // and intercept is needed. @@ -116,7 +116,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * @param testData array representing a single data point * @return Double prediction from the trained model */ - @deprecated + @Deprecated def predict(testData: Vector): Double = { predictPoint(testData, weights, intercept) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 912de6651aa7..66a637661900 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -34,12 +34,16 @@ class LassoModel ( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override protected def computeScore( - dataMatrix: Vector, - weightMatrix: Vector, - intercept: Double): Double = { - weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - } + /** + * DEPRECATED: Use predictScore(...) instead + */ + @Deprecated + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, + intercept: Double): Double = + predictScore(dataMatrix) + } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 3867639dc919..5ac88e2a4cc9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -32,7 +32,11 @@ class LinearRegressionModel ( override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - @deprecated + + /** + * DEPRECATED: Use predictScore(...) instead + */ + @Deprecated override protected def predictPoint( dataMatrix: Vector, weightMatrix: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 7a3bd0c74ee4..b915e5dbeaca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -33,11 +33,14 @@ class RidgeRegressionModel ( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - @deprecated + /** + * DEPRECATED: Use predictScore(...) 
instead + */ + @Deprecated override protected def predictPoint( dataMatrix: Vector, weightMatrix: Vector, - intercept: Double) = + intercept: Double): Double = predictScore(dataMatrix) } From 39203398a346225538143e38b871100a7ce2e327 Mon Sep 17 00:00:00 2001 From: Christoph Sawade Date: Mon, 1 Sep 2014 19:05:16 +0200 Subject: [PATCH 5/5] WIP --- .../classification/BinaryClassificationModel.scala | 10 ---------- .../apache/spark/mllib/classification/NaiveBayes.scala | 8 ++++++++ .../mllib/regression/GeneralizedLinearAlgorithm.scala | 8 ++++++++ 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala index a382ce8c24c4..7582c18f3195 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/BinaryClassificationModel.scala @@ -107,14 +107,4 @@ class BinaryClassificationModel ( if (useThreshold) predictClass(testData) else predictScore(testData) } - - /** - * DEPRECATED: Use predictScore(...) or predictClass(...) instead - * Predict values for examples stored in a JavaRDD. - * @param testData JavaRDD representing data points to be predicted - * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction - */ - @Deprecated - def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = - predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 6b50c3892aeb..3534e200843f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD +import org.apache.spark.api.java.JavaRDD /** * Model for Naive Bayes Classifiers. @@ -77,6 +78,13 @@ class NaiveBayesModel private[mllib] ( */ @deprecated def predict(testData: Vector): Double = predictClass(testData) + + /** + * DEPRECATED: Use predictClass(...) 
instead + */ + @Deprecated + def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = + predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index d04d63c22b1f..54b97384d981 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLUtils._ +import org.apache.spark.api.java.JavaRDD /** * :: DeveloperApi :: @@ -120,6 +121,13 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double def predict(testData: Vector): Double = { predictPoint(testData, weights, intercept) } + + /** + * DEPRECATED: Use predictScore(...) instead + */ + @Deprecated + def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = + predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } /**