35 changes: 34 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml

import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -62,6 +62,39 @@ private[ml] trait PredictorParams extends Params
}
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}

Contributor Author: I placed it in PredictorParams so that methods like GBTModel.evaluateEachIteration can reuse it in the future.

/**
* Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
*/
protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
val w = this match {
case p: HasWeightCol =>
if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
col($(p.weightCol)).cast(DoubleType)
} else {
lit(1.0)

Member: Here too, do you need a weight col if the implementation doesn't support it (and shouldn't be calling this method)? Or is it different?

Contributor Author: It is different from the place above. Even if an ML implementation supports weighting, its weightCol does not have to be set; in that case, lit(1.0) is used implicitly. Currently, all algorithms that support weighting handle weightCol this way.

}
}

dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
}
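
The exchange above is about the weight-column fallback: when weightCol is not set (or is empty), every row implicitly gets weight 1.0. A minimal standalone sketch of that fallback, using only Spark SQL's col/lit functions (the weightColumn helper below is illustrative, not part of the PR):

// --- illustrative sketch, not part of the diff ---
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

// Use the configured weight column when it is present and non-empty; otherwise fall
// back to a constant weight of 1.0, which is what the match above does implicitly.
def weightColumn(weightCol: Option[String]): Column =
  weightCol.filter(_.nonEmpty).map(name => col(name).cast(DoubleType)).getOrElse(lit(1.0))

// weightColumn(Some("w"))  // the "w" column, cast to Double
// weightColumn(None)       // lit(1.0): every row weighted equally
// --- end sketch ---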

/**
* Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
* Validate the output instances with the given function.
*/
protected def extractInstances(dataset: Dataset[_],
validateInstance: Instance => Unit): RDD[Instance] = {
extractInstances(dataset).map { instance =>
validateInstance(instance)
instance
}
}
}
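
The second overload runs the supplied check inside the RDD's map, so validation is lazy: a bad row only fails when the resulting RDD is evaluated by an action, not when extractInstances returns. A rough self-contained sketch of that behaviour, assuming a local SparkSession; the object name and the stand-in Instance case class are illustrative only, since Spark's ml.feature.Instance is package-private:

// --- illustrative sketch, not part of the diff ---
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession

object LazyValidationSketch {
  // Stand-in for org.apache.spark.ml.feature.Instance, which is package-private.
  case class Instance(label: Double, weight: Double, features: Vector)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()

    val instances = spark.sparkContext.parallelize(Seq(
      Instance(0.0, 1.0, Vectors.dense(1.0)),
      Instance(-1.0, 1.0, Vectors.dense(2.0))))  // invalid: negative label

    // Same shape as extractInstances(dataset, validateInstance): validate, then pass through.
    val validateInstance = (instance: Instance) =>
      require(instance.label >= 0.0, s"negative label ${instance.label}")
    val validated = instances.map { instance =>
      validateInstance(instance)
      instance
    }

    // Nothing has failed yet: map is lazy.
    // validated.count()  // uncommenting this action would throw for the -1.0 label
    spark.stop()
  }
}
// --- end sketch ---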

/**
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
@@ -42,6 +42,22 @@ private[spark] trait ClassifierParams
val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
}

/**
* Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
* Validates the label on the classifier is a valid integer in the range [0, numClasses).
*/
protected def extractInstances(dataset: Dataset[_],
numClasses: Int): RDD[Instance] = {
val validateInstance = (instance: Instance) => {
val label = instance.label
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
s" [0, $numClasses).")
}
extractInstances(dataset, validateInstance)
}
}
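
The require above encodes what counts as a valid classification label: it must be an integral value in [0, numClasses). A small standalone restatement of the same predicate with worked cases; the checkLabel name is illustrative only:

// --- illustrative sketch, not part of the diff ---
// The same predicate as the require above, returned as a Boolean instead of throwing.
def checkLabel(label: Double, numClasses: Int): Boolean =
  label.toLong == label && label >= 0 && label < numClasses

// checkLabel(1.0, 3)   // true:  integral and within [0, 3)
// checkLabel(2.5, 3)   // false: not an integer
// checkLabel(3.0, 3)   // false: out of range (valid labels are 0.0, 1.0, 2.0)
// checkLabel(-1.0, 3)  // false: negative
// --- end sketch ---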

/**
@@ -22,7 +22,7 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}

/**
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -116,23 +115,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
instr.logPipelineStage(this)
instr.logDataset(dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses = getNumClasses(dataset)

if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
validateNumClasses(numClasses)
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
validateLabel(label, numClasses)
Instance(label, weight, features)
}
val instances = extractInstances(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
instr.logNumClasses(numClasses)
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
@@ -36,9 +36,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}

/** Params for linear SVM Classifier. */
private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
@@ -161,12 +159,7 @@
override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)

override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr =>
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
val instances = extractInstances(dataset)

instr.logPipelineStage(this)
instr.logDataset(dataset)
@@ -40,9 +40,8 @@ import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, Multiclas
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils
@@ -492,12 +491,7 @@
protected[spark] def train(
dataset: Dataset[_],
handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
val instances = extractInstances(dataset)

if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
import org.apache.spark.ml.feature.OneHotEncoderModel
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
@@ -21,14 +21,15 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.functions.col

/**
* Params for Naive Bayes Classifiers.
@@ -137,35 +138,30 @@ class NaiveBayes @Since("1.5.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

val modelTypeValue = $(modelType)
val requireValues: Vector => Unit = {
modelTypeValue match {
case Multinomial =>
requireNonnegativeValues
case Bernoulli =>
requireZeroOneBernoulliValues
case _ =>
// This should never happen.
throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}
val validateInstance = $(modelType) match {
case Multinomial =>
(instance: Instance) => requireNonnegativeValues(instance.features)
case Bernoulli =>
(instance: Instance) => requireZeroOneBernoulliValues(instance.features)
case _ =>
// This should never happen.
throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}

instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
probabilityCol, modelType, smoothing, thresholds)

val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
instr.logNumFeatures(numFeatures)
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))

// Aggregates term frequencies per label.
// TODO: Calling aggregateByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
.map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
}.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
val aggregated = extractInstances(dataset, validateInstance).map { instance =>
(instance.label, (instance.weight, instance.features))
}.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
seqOp = {
case ((weightSum, featureSum, count), (weight, features)) =>
requireValues(features)
BLAS.axpy(weight, features, featureSum)
(weightSum + weight, featureSum, count + 1)
},
@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType


/**
@@ -118,12 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
val instances = extractInstances(dataset)
val strategy = getOldStrategy(categoricalFeatures)

instr.logPipelineStage(this)
@@ -43,7 +43,6 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -320,13 +319,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr =>
// Extract the number of features before deciding optimization solver.
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))

val instances: RDD[Instance] = dataset.select(
col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
val instances = extractInstances(dataset)

instr.logPipelineStage(this)
instr.logDataset(dataset)