diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 58815434cbdaf..9eac8ed22a3f6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -18,7 +18,7 @@ package org.apache.spark.ml
 
 import org.apache.spark.annotation.{DeveloperApi, Since}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -62,6 +62,40 @@ private[ml] trait PredictorParams extends Params
     }
     SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
   }
+
+  /**
+   * Extract [[labelCol]], weightCol (if any) and [[featuresCol]] from the given dataset,
+   * and put them in an RDD with strong types.
+   */
+  protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
+    val w = this match {
+      case p: HasWeightCol =>
+        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
+          col($(p.weightCol)).cast(DoubleType)
+        } else {
+          lit(1.0)
+        }
+      case _ => lit(1.0)
+    }
+
+    dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      case Row(label: Double, weight: Double, features: Vector) =>
+        Instance(label, weight, features)
+    }
+  }
+
+  /**
+   * Extract [[labelCol]], weightCol (if any) and [[featuresCol]] from the given dataset,
+   * and put them in an RDD with strong types.
+   * Validate each output instance with the given function.
+   */
+  protected def extractInstances(dataset: Dataset[_],
+      validateInstance: Instance => Unit): RDD[Instance] = {
+    extractInstances(dataset).map { instance =>
+      validateInstance(instance)
+      instance
+    }
+  }
 }
 
 /**
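
The two helpers above centralize extraction logic previously duplicated in each estimator's train(): the label is cast to DoubleType, the weight column is used when the estimator mixes in HasWeightCol and the param is set, and a constant weight of 1.0 is used otherwise. A rough, self-contained sketch of the same extraction semantics outside Spark's internals; LocalInstance stands in for the private[ml] Instance class, and the column names, sample rows and local master are illustrative, not part of the patch:

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

object WeightedExtractionSketch {
  // Stand-in for the private[ml] Instance case class used by the patch.
  case class LocalInstance(label: Double, weight: Double, features: Vector)

  // Mirrors the helper: cast the label to double, use the weight column when one
  // is configured, and fall back to a constant weight of 1.0 otherwise.
  def extract(df: DataFrame, weightCol: Option[String]): RDD[LocalInstance] = {
    val w = weightCol match {
      case Some(c) if c.nonEmpty => col(c).cast(DoubleType)
      case _ => lit(1.0)
    }
    df.select(col("label").cast(DoubleType), w, col("features")).rdd.map {
      case Row(label: Double, weight: Double, features: Vector) =>
        LocalInstance(label, weight, features)
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical training data with an explicit weight column.
    val df = Seq(
      (1.0, 2.0, Vectors.dense(0.0, 1.1)),
      (0.0, 0.5, Vectors.dense(2.0, 1.0))
    ).toDF("label", "weight", "features")

    extract(df, Some("weight")).collect().foreach(println)  // keeps the per-row weights
    extract(df, None).collect().foreach(println)            // every weight becomes 1.0
    spark.stop()
  }
}

Passing Some("weight") keeps the per-row weights, while None yields a constant weight of 1.0 for every instance, matching the lit(1.0) fallback above.
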
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index b6b02e77909bd..9ac673078d4ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
@@ -42,6 +42,22 @@ private[spark] trait ClassifierParams
     val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
     SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
   }
+
+  /**
+   * Extract [[labelCol]], weightCol (if any) and [[featuresCol]] from the given dataset,
+   * and put them in an RDD with strong types.
+   * Validates that each label is a valid integer in the range [0, numClasses).
+   */
+  protected def extractInstances(dataset: Dataset[_],
+      numClasses: Int): RDD[Instance] = {
+    val validateInstance = (instance: Instance) => {
+      val label = instance.label
+      require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
+        s" dataset with invalid label $label. Labels must be integers in range" +
+        s" [0, $numClasses).")
+    }
+    extractInstances(dataset, validateInstance)
+  }
 }
 
 /**
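
The check runs inside the map over the extracted RDD, so an invalid label surfaces when the instances RDD is first computed, not at the extractInstances call site. A tiny sketch of the predicate itself, with illustrative values:

object LabelCheckSketch extends App {
  // Same predicate the require(...) above enforces for each instance.
  def isValidLabel(label: Double, numClasses: Int): Boolean =
    label.toLong == label && label >= 0 && label < numClasses

  assert(isValidLabel(0.0, 3))     // smallest valid label
  assert(isValidLabel(2.0, 3))     // integer-valued and inside [0, 3)
  assert(!isValidLabel(2.5, 3))    // not integer-valued
  assert(!isValidLabel(3.0, 3))    // equal to numClasses, so out of range
  assert(!isValidLabel(-1.0, 3))   // negative labels are rejected
}
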
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 6bd8a26f5b1a8..2d0212f36fad4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -22,7 +22,7 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit, udf}
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
 
 /**
  * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -116,9 +115,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
       dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
-    val categoricalFeatures: Map[Int, Int] =
-      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = getNumClasses(dataset)
+    val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+    val numClasses = getNumClasses(dataset)
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -126,13 +124,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
     validateNumClasses(numClasses)
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          validateLabel(label, numClasses)
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset, numClasses)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
     instr.logNumClasses(numClasses)
     instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 78503585261bf..e467228b4cc14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -36,9 +36,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
 
 /** Params for linear SVM Classifier. */
 private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
@@ -161,12 +159,7 @@ class LinearSVC @Since("2.2.0") (
   override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
 
   override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr =>
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset)
 
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 0997c1e7b38d6..af6e2b39ecb60 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -40,9 +40,8 @@ import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, Multiclas
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils
@@ -492,12 +491,7 @@ class LogisticRegression @Since("1.2.0") (
   protected[spark] def train(
       dataset: Dataset[_],
       handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset)
 
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 47b8a8df637b9..41db6f3f44342 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
 import org.apache.spark.ml.feature.OneHotEncoderModel
-import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index e97af0582d358..205f565aa2685 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.HasWeightCol
@@ -28,7 +29,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.functions.col
 
 /**
  * Params for Naive Bayes Classifiers.
@@ -137,17 +138,14 @@ class NaiveBayes @Since("1.5.0") (
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
 
-    val modelTypeValue = $(modelType)
-    val requireValues: Vector => Unit = {
-      modelTypeValue match {
-        case Multinomial =>
-          requireNonnegativeValues
-        case Bernoulli =>
-          requireZeroOneBernoulliValues
-        case _ =>
-          // This should never happen.
-          throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
-      }
+    val validateInstance = $(modelType) match {
+      case Multinomial =>
+        (instance: Instance) => requireNonnegativeValues(instance.features)
+      case Bernoulli =>
+        (instance: Instance) => requireZeroOneBernoulliValues(instance.features)
+      case _ =>
+        // This should never happen.
+        throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
     }
 
     instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
@@ -155,17 +153,15 @@ class NaiveBayes @Since("1.5.0") (
 
     val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
     instr.logNumFeatures(numFeatures)
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
     // Aggregates term frequencies per label.
     // TODO: Calling aggregateByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
-    val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
-      .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
-      }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
+    val aggregated = extractInstances(dataset, validateInstance).map { instance =>
+      (instance.label, (instance.weight, instance.features))
+    }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
       seqOp = {
         case ((weightSum, featureSum, count), (weight, features)) =>
-          requireValues(features)
           BLAS.axpy(weight, features, featureSum)
           (weightSum + weight, featureSum, count + 1)
       },
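
With this change the model-type-specific feature checks run once per row inside extractInstances(dataset, validateInstance) instead of inside the aggregation's seqOp; the amount of checking is the same, only the place it happens moves. The adaptation is just lifting a Vector check into an Instance check, roughly as in this sketch (LocalInstance and requireNonnegative are illustrative stand-ins; the real requireNonnegativeValues is private to NaiveBayes):

import org.apache.spark.ml.linalg.{Vector, Vectors}

object ValidateInstanceSketch extends App {
  // Stand-in for the private[ml] Instance case class.
  case class LocalInstance(label: Double, weight: Double, features: Vector)

  // Vector-level check in the spirit of NaiveBayes.requireNonnegativeValues.
  val requireNonnegative: Vector => Unit = v =>
    require(v.toArray.forall(_ >= 0.0), s"Found a negative feature value in $v.")

  // The wrapping done by the patch: lift a Vector check to an Instance check.
  val validateInstance: LocalInstance => Unit =
    instance => requireNonnegative(instance.features)

  validateInstance(LocalInstance(1.0, 1.0, Vectors.dense(0.0, 2.0)))     // passes
  // validateInstance(LocalInstance(1.0, 1.0, Vectors.dense(-1.0, 2.0))) // would throw
}
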
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 106be1b78af47..602b5fac20d3b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
 
 
 /**
@@ -118,12 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
       dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset)
     val strategy = getOldStrategy(categoricalFeatures)
 
     instr.logPipelineStage(this)
(@Since("1.3.0") override val uid: String override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr => // Extract the number of features before deciding optimization solver. val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size - val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = dataset.select( - col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances = extractInstances(dataset) instr.logPipelineStage(this) instr.logDataset(dataset)