-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19591][ML][MLlib] Add sample weights to decision trees #16722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,7 +31,7 @@ object TestingUtils { | |
| * Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue, | ||
| * the relative tolerance is meaningless, so the exception will be raised to warn users. | ||
| */ | ||
| private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { | ||
| private[ml] def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { | ||
| val absX = math.abs(x) | ||
| val absY = math.abs(y) | ||
| val diff = math.abs(x - y) | ||
|
|
@@ -48,7 +48,7 @@ object TestingUtils { | |
| /** | ||
| * Private helper function for comparing two values using absolute tolerance. | ||
| */ | ||
| private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { | ||
| private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see this used in the new code, maybe my search is not working properly in browser
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not used. I just changed the scope of both methods, I can change it back of course. I don't see a great reason to make this public since most users will use
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, I don't have a very strong opinion here either |
||
| math.abs(x - y) < eps | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,18 +22,21 @@ import org.json4s.{DefaultFormats, JObject} | |
| import org.json4s.JsonDSL._ | ||
|
|
||
| import org.apache.spark.annotation.Since | ||
| import org.apache.spark.ml.feature.LabeledPoint | ||
| import org.apache.spark.ml.feature.{Instance, LabeledPoint} | ||
| import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} | ||
| import org.apache.spark.ml.param.ParamMap | ||
| import org.apache.spark.ml.param.shared.HasWeightCol | ||
| import org.apache.spark.ml.tree._ | ||
| import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams} | ||
| import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ | ||
| import org.apache.spark.ml.tree.impl.RandomForest | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} | ||
| import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.Dataset | ||
|
|
||
| import org.apache.spark.sql.{Dataset, Row} | ||
| import org.apache.spark.sql.functions.{col, lit} | ||
| import org.apache.spark.sql.types.DoubleType | ||
|
|
||
| /** | ||
| * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning) | ||
|
|
@@ -65,6 +68,9 @@ class DecisionTreeClassifier @Since("1.4.0") ( | |
| override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.2.0") | ||
| def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value) | ||
|
|
||
| @Since("1.4.0") | ||
| override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) | ||
|
|
||
|
|
@@ -96,6 +102,16 @@ class DecisionTreeClassifier @Since("1.4.0") ( | |
| @Since("1.6.0") | ||
| override def setSeed(value: Long): this.type = set(seed, value) | ||
|
|
||
| /** | ||
| * Sets the value of param [[weightCol]]. | ||
| * If this is not set or empty, we treat all instance weights as 1.0. | ||
| * Default is not set, so all instances have weight one. | ||
| * | ||
| * @group setParam | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it looks like by removing this method call you are removing some valuable validation logic (that exists in the base class). require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. Actually this problem exists elsewhere (LogisticRegression, e.g.) What to do you think about adding it back manually here and then addressing the larger issue in a separate JIRA?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would say that's fine if it was only in one place, but I also see this pattern in DecisionTreeRegressor.scala, it seems like we should be able to refactor this part out
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For regressors,
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds reasonable, thanks for the explanation. |
||
| */ | ||
| @Since("2.2.0") | ||
| def setWeightCol(value: String): this.type = set(weightCol, value) | ||
|
|
||
| override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { | ||
| val categoricalFeatures: Map[Int, Int] = | ||
| MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) | ||
|
|
@@ -106,14 +122,23 @@ class DecisionTreeClassifier @Since("1.4.0") ( | |
| ".train() called with non-matching numClasses and thresholds.length." + | ||
| s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") | ||
| } | ||
|
|
||
| val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) | ||
| require(numClasses > 0, s"DecisionTreeClassifier (in extractLabeledPoints) found numClasses =" + | ||
| s" $numClasses, but requires numClasses > 0.") | ||
| val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) | ||
| val instances = | ||
| dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { | ||
| case Row(label: Double, weight: Double, features: Vector) => | ||
| require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" + | ||
| s" dataset with invalid label $label. Labels must be integers in range" + | ||
| s" [0, $numClasses).") | ||
| Instance(label, weight, features) | ||
| } | ||
| val strategy = getOldStrategy(categoricalFeatures, numClasses) | ||
|
|
||
| val instr = Instrumentation.create(this, oldDataset) | ||
| val instr = Instrumentation.create(this, instances) | ||
| instr.logParams(params: _*) | ||
|
|
||
| val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", | ||
| val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all", | ||
| seed = $(seed), instr = Some(instr), parentUID = Some(uid)) | ||
|
|
||
| val m = trees.head.asInstanceOf[DecisionTreeClassificationModel] | ||
|
|
@@ -124,11 +149,12 @@ class DecisionTreeClassifier @Since("1.4.0") ( | |
| /** (private[ml]) Train a decision tree on an RDD */ | ||
| private[ml] def train(data: RDD[LabeledPoint], | ||
| oldStrategy: OldStrategy): DecisionTreeClassificationModel = { | ||
| val instr = Instrumentation.create(this, data) | ||
| instr.logParams(params: _*) | ||
|
|
||
| val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", | ||
| seed = 0L, instr = Some(instr), parentUID = Some(uid)) | ||
| val instances = data.map(_.toInstance(1.0)) | ||
| val instr = Instrumentation.create(this, instances) | ||
| instr.logParams(params: _*) | ||
| val trees = RandomForest.run(instances, oldStrategy, numTrees = 1, | ||
| featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid)) | ||
|
|
||
| val m = trees.head.asInstanceOf[DecisionTreeClassificationModel] | ||
| instr.logSuccess(m) | ||
|
|
@@ -176,6 +202,7 @@ class DecisionTreeClassificationModel private[ml] ( | |
|
|
||
| /** | ||
| * Construct a decision tree classification model. | ||
| * | ||
| * @param rootNode Root node of tree, with other nodes attached. | ||
| */ | ||
| private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,19 +21,20 @@ import org.json4s.{DefaultFormats, JObject} | |
| import org.json4s.JsonDSL._ | ||
|
|
||
| import org.apache.spark.annotation.Since | ||
| import org.apache.spark.ml.feature.LabeledPoint | ||
| import org.apache.spark.ml.feature.Instance | ||
| import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} | ||
| import org.apache.spark.ml.param.ParamMap | ||
| import org.apache.spark.ml.tree._ | ||
| import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel} | ||
| import org.apache.spark.ml.tree.impl.RandomForest | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.ml.util.{Identifiable, MetadataUtils} | ||
| import org.apache.spark.ml.util.DefaultParamsReader.Metadata | ||
| import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} | ||
| import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
| import org.apache.spark.sql.functions._ | ||
|
|
||
| import org.apache.spark.sql.functions.{col, udf} | ||
|
|
||
| /** | ||
| * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for | ||
|
|
@@ -126,20 +127,20 @@ class RandomForestClassifier @Since("1.4.0") ( | |
| s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as above, it looks like some validation logic is missing here |
||
| val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) | ||
| val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance(1.0)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. minor simplification -
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. update: since you removed the overload now this comment is no longer valid. |
||
| val strategy = | ||
| super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) | ||
|
|
||
| val instr = Instrumentation.create(this, oldDataset) | ||
| val instr = Instrumentation.create(this, instances) | ||
| instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, | ||
| impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, | ||
| minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval) | ||
|
|
||
| val trees = RandomForest | ||
| .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) | ||
| .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) | ||
| .map(_.asInstanceOf[DecisionTreeClassificationModel]) | ||
|
|
||
| val numFeatures = oldDataset.first().features.size | ||
| val numFeatures = instances.first().features.size | ||
| val m = new RandomForestClassificationModel(trees, numFeatures, numClasses) | ||
| instr.logSuccess(m) | ||
| m | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,8 +23,9 @@ import org.json4s.JsonDSL._ | |
|
|
||
| import org.apache.spark.annotation.Since | ||
| import org.apache.spark.ml.{PredictionModel, Predictor} | ||
| import org.apache.spark.ml.feature.LabeledPoint | ||
| import org.apache.spark.ml.feature.{Instance, LabeledPoint} | ||
| import org.apache.spark.ml.linalg.Vector | ||
| import org.apache.spark.ml.param.shared.HasWeightCol | ||
| import org.apache.spark.ml.param.ParamMap | ||
| import org.apache.spark.ml.tree._ | ||
| import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ | ||
|
|
@@ -33,8 +34,10 @@ import org.apache.spark.ml.util._ | |
| import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} | ||
| import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row} | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.functions.{col, lit} | ||
| import org.apache.spark.sql.types.DoubleType | ||
|
|
||
|
|
||
| /** | ||
|
|
@@ -64,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S | |
| override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.2.0") | ||
| def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value) | ||
|
|
||
| @Since("1.4.0") | ||
| override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) | ||
|
|
||
|
|
@@ -99,16 +105,31 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S | |
| @Since("2.0.0") | ||
| def setVarianceCol(value: String): this.type = set(varianceCol, value) | ||
|
|
||
| /** | ||
| * Sets the value of param [[weightCol]]. | ||
| * If this is not set or empty, we treat all instance weights as 1.0. | ||
| * Default is not set, so all instances have weight one. | ||
| * | ||
| * @group setParam | ||
| */ | ||
| @Since("2.2.0") | ||
| def setWeightCol(value: String): this.type = set(weightCol, value) | ||
|
|
||
| override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { | ||
| val categoricalFeatures: Map[Int, Int] = | ||
| MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) | ||
| val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) | ||
| val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) | ||
| val instances = | ||
| dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { | ||
| case Row(label: Double, weight: Double, features: Vector) => | ||
| Instance(label, weight, features) | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the code above looks the same as the classifier, can we refactor somehow:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. update: it sounds like you are going to create a separate JIRA for refactoring this code, that is reasonable to me. |
||
| val strategy = getOldStrategy(categoricalFeatures) | ||
|
|
||
| val instr = Instrumentation.create(this, oldDataset) | ||
| val instr = Instrumentation.create(this, instances) | ||
| instr.logParams(params: _*) | ||
|
|
||
| val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", | ||
| val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all", | ||
| seed = $(seed), instr = Some(instr), parentUID = Some(uid)) | ||
|
|
||
| val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] | ||
|
|
@@ -122,8 +143,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S | |
| val instr = Instrumentation.create(this, data) | ||
| instr.logParams(params: _*) | ||
|
|
||
| val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", | ||
| seed = $(seed), instr = Some(instr), parentUID = Some(uid)) | ||
| val instances = data.map(_.toInstance(1.0)) | ||
| val trees = RandomForest.run(instances, oldStrategy, numTrees = 1, | ||
| featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) | ||
|
|
||
| val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] | ||
| instr.logSuccess(m) | ||
|
|
@@ -153,6 +175,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor | |
| * <a href="http://en.wikipedia.org/wiki/Decision_tree_learning"> | ||
| * Decision tree (Wikipedia)</a> model for regression. | ||
| * It supports both continuous and categorical features. | ||
| * | ||
| * @param rootNode Root of the decision tree | ||
| */ | ||
| @Since("1.4.0") | ||
|
|
@@ -171,6 +194,7 @@ class DecisionTreeRegressionModel private[ml] ( | |
|
|
||
| /** | ||
| * Construct a decision tree regression model. | ||
| * | ||
| * @param rootNode Root node of tree, with other nodes attached. | ||
| */ | ||
| private[ml] def this(rootNode: Node, numFeatures: Int) = | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,6 @@ import org.json4s.JsonDSL._ | |
|
|
||
| import org.apache.spark.annotation.Since | ||
| import org.apache.spark.ml.{PredictionModel, Predictor} | ||
| import org.apache.spark.ml.feature.LabeledPoint | ||
| import org.apache.spark.ml.linalg.Vector | ||
| import org.apache.spark.ml.param.ParamMap | ||
| import org.apache.spark.ml.tree._ | ||
|
|
@@ -31,10 +30,8 @@ import org.apache.spark.ml.util._ | |
| import org.apache.spark.ml.util.DefaultParamsReader.Metadata | ||
| import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} | ||
| import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
| import org.apache.spark.sql.functions._ | ||
|
|
||
| import org.apache.spark.sql.functions.{col, udf} | ||
|
|
||
| /** | ||
| * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> | ||
|
|
@@ -117,20 +114,20 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S | |
| override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { | ||
| val categoricalFeatures: Map[Int, Int] = | ||
| MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) | ||
| val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) | ||
|
|
||
| val instances = extractLabeledPoints(dataset).map(_.toInstance(1.0)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. simplify to toInstance (without the 1.0) |
||
| val strategy = | ||
| super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) | ||
|
|
||
| val instr = Instrumentation.create(this, oldDataset) | ||
| val instr = Instrumentation.create(this, instances) | ||
| instr.logParams(labelCol, featuresCol, predictionCol, impurity, numTrees, | ||
| featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, | ||
| minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval) | ||
|
|
||
| val trees = RandomForest | ||
| .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) | ||
| .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) | ||
| .map(_.asInstanceOf[DecisionTreeRegressionModel]) | ||
|
|
||
| val numFeatures = oldDataset.first().features.size | ||
| val numFeatures = instances.first().features.size | ||
| val m = new RandomForestRegressionModel(trees, numFeatures) | ||
| instr.logSuccess(m) | ||
| m | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might it be better to just make this public, if we are using it in tests, similar to other test methods?