
Commit fff2e84

zhengruifeng authored and srowen committed
[SPARK-29095][ML] add extractInstances
### What changes were proposed in this pull request?
Add common methods to extract weighted instances.

### Why are the changes needed?
Today more and more ML algorithms support weighting; adding these methods keeps the implementations simple.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Existing testsuites.

Closes #25802 from zhengruifeng/add_extractInstances.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
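For context, the strong type the new helpers produce is org.apache.spark.ml.feature.Instance. Its definition is not part of this diff; as a reminder, it is roughly the following (a sketch from memory; note the class is private[ml]):

case class Instance(label: Double, weight: Double, features: Vector)

One Instance is a single labeled, weighted training point with a strongly typed feature vector.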
1 parent 7c02c14 commit fff2e84

File tree

9 files changed: +78 −66 lines changed

mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

Lines changed: 34 additions & 1 deletion

@@ -18,7 +18,7 @@
 package org.apache.spark.ml
 
 import org.apache.spark.annotation.{DeveloperApi, Since}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -62,6 +62,39 @@ private[ml] trait PredictorParams extends Params
     }
     SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
   }
+
+  /**
+   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
+   * and put it in an RDD with strong types.
+   */
+  protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
+    val w = this match {
+      case p: HasWeightCol =>
+        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
+          col($(p.weightCol)).cast(DoubleType)
+        } else {
+          lit(1.0)
+        }
+    }
+
+    dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      case Row(label: Double, weight: Double, features: Vector) =>
+        Instance(label, weight, features)
+    }
+  }
+
+  /**
+   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
+   * and put it in an RDD with strong types.
+   * Validate the output instances with the given function.
+   */
+  protected def extractInstances(dataset: Dataset[_],
+      validateInstance: Instance => Unit): RDD[Instance] = {
+    extractInstances(dataset).map { instance =>
+      validateInstance(instance)
+      instance
+    }
+  }
 }
 
 /**
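To make the weight-resolution rule concrete, here is a hedged usage sketch (the spark session, the DataFrame df, and the column names are assumptions for illustration, e.g. in spark-shell):

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression

// Hypothetical weighted training data.
val df = spark.createDataFrame(Seq(
  (1.0, 2.0, Vectors.dense(0.1, 0.3)),
  (0.0, 0.5, Vectors.dense(0.7, 0.9))
)).toDF("label", "weight", "features")

// With weightCol set, extractInstances reads the "weight" column (cast to double);
// if weightCol is unset or empty, every row gets weight 1.0 via lit(1.0).
val model = new LinearRegression().setWeightCol("weight").fit(df)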

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 17 additions & 1 deletion

@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
@@ -42,6 +42,22 @@ private[spark] trait ClassifierParams
     val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
     SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
   }
+
+  /**
+   * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
+   * and put it in an RDD with strong types.
+   * Validates the label on the classifier is a valid integer in the range [0, numClasses).
+   */
+  protected def extractInstances(dataset: Dataset[_],
+      numClasses: Int): RDD[Instance] = {
+    val validateInstance = (instance: Instance) => {
+      val label = instance.label
+      require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
+        s" dataset with invalid label $label. Labels must be integers in range" +
+        s" [0, $numClasses).")
+    }
+    extractInstances(dataset, validateInstance)
+  }
 }
 
 /**
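The label check above is plain arithmetic; a self-contained sketch of what it accepts and rejects (numbers chosen for illustration):

val numClasses = 3
def isValidLabel(label: Double): Boolean =
  label.toLong == label && label >= 0 && label < numClasses

assert(isValidLabel(0.0) && isValidLabel(2.0)) // integral, in [0, 3)
assert(!isValidLabel(2.5))  // rejected: not integral
assert(!isValidLabel(3.0))  // rejected: outside [0, 3)
assert(!isValidLabel(-1.0)) // rejected: negative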

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 6 additions & 14 deletions

@@ -22,7 +22,7 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit, udf}
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
 
 /**
  * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -116,23 +115,16 @@
       dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
-    val categoricalFeatures: Map[Int, Int] =
-      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val numClasses: Int = getNumClasses(dataset)
+    val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+    val numClasses = getNumClasses(dataset)
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
        ".train() called with non-matching numClasses and thresholds.length." +
        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
     validateNumClasses(numClasses)
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          validateLabel(label, numClasses)
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset, numClasses)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
     instr.logNumClasses(numClasses)
     instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
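Nothing changes from the caller's perspective; a hedged usage sketch (the DataFrame df and its column names are assumptions for illustration):

import org.apache.spark.ml.classification.DecisionTreeClassifier

val dtc = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setWeightCol("weight") // now consumed via extractInstances(dataset, numClasses)

// Label validation still happens per row while the instances RDD is built;
// an out-of-range or fractional label fails the require() in ClassifierParams.
val model = dtc.fit(df)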

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 1 addition & 8 deletions

@@ -36,9 +36,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
 
 /** Params for linear SVM Classifier. */
 private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
@@ -161,12 +159,7 @@
   override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
 
   override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr =>
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset)
 
     instr.logPipelineStage(this)
     instr.logDataset(dataset)

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 2 additions & 8 deletions

@@ -40,9 +40,8 @@ import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, Multiclas
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils
@@ -492,12 +491,7 @@
   protected[spark] def train(
       dataset: Dataset[_],
      handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset)
 
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
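One subtle difference worth noting: the inline code removed here selected the label column without a cast, while the shared helper casts label (and weight) to DoubleType, so the Row(label: Double, ...) match is less sensitive to the label column's declared type. A comment-only sketch of the two select shapes:

// Before, inline in each algorithm (no cast on the label):
//   dataset.select(col($(labelCol)), w, col($(featuresCol)))
// After, inside the shared helper:
//   dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol)))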

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
 import org.apache.spark.ml.feature.OneHotEncoderModel
-import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._

mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

Lines changed: 13 additions & 17 deletions

@@ -21,14 +21,15 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.functions.col
 
 /**
  * Params for Naive Bayes Classifiers.
@@ -137,35 +138,30 @@
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
 
-    val modelTypeValue = $(modelType)
-    val requireValues: Vector => Unit = {
-      modelTypeValue match {
-        case Multinomial =>
-          requireNonnegativeValues
-        case Bernoulli =>
-          requireZeroOneBernoulliValues
-        case _ =>
-          // This should never happen.
-          throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
-      }
+    val validateInstance = $(modelType) match {
+      case Multinomial =>
+        (instance: Instance) => requireNonnegativeValues(instance.features)
+      case Bernoulli =>
+        (instance: Instance) => requireZeroOneBernoulliValues(instance.features)
+      case _ =>
+        // This should never happen.
+        throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
     }
 
     instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
       probabilityCol, modelType, smoothing, thresholds)
 
     val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
     instr.logNumFeatures(numFeatures)
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
     // Aggregates term frequencies per label.
     // TODO: Calling aggregateByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
-    val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
-      .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
-      }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
+    val aggregated = extractInstances(dataset, validateInstance).map { instance =>
+      (instance.label, (instance.weight, instance.features))
+    }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
       seqOp = {
        case ((weightSum, featureSum, count), (weight, features)) =>
-          requireValues(features)
          BLAS.axpy(weight, features, featureSum)
          (weightSum + weight, featureSum, count + 1)
      },
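A design note on the NaiveBayes change: the per-row feature check moved out of the aggregation's seqOp and into a closure handed to extractInstances, which applies it inside a map, so validation stays lazy and still costs only the single pass over the data. A standalone sketch of the pattern (Instance is private[ml], so this lives conceptually inside org.apache.spark.ml; the weight check is an assumed example, not the check NaiveBayes uses):

val validateInstance: Instance => Unit = { instance =>
  require(instance.weight >= 0.0, s"negative weight: ${instance.weight}")
}

// extractInstances(dataset, validateInstance) behaves like
// extractInstances(dataset).map { i => validateInstance(i); i }
// so no extra job or pass over the data is introduced.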

mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala

Lines changed: 3 additions & 9 deletions

@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
 
 
 /**
@@ -118,12 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
       dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
+    val instances = extractInstances(dataset)
     val strategy = getOldStrategy(categoricalFeatures)
 
     instr.logPipelineStage(this)

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 1 addition & 7 deletions

@@ -43,7 +43,6 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -320,13 +319,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr =>
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
-    val instances: RDD[Instance] = dataset.select(
-      col($(labelCol)), w, col($(featuresCol))).rdd.map {
-      case Row(label: Double, weight: Double, features: Vector) =>
-        Instance(label, weight, features)
-    }
+    val instances = extractInstances(dataset)
 
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
