[SPARK-18239][SPARKR] Gradient Boosted Tree for R #15746
Conversation
  class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
                              RandomForestParams, TreeRegressorParams, HasCheckpointInterval,
-                             JavaMLWritable, JavaMLReadable, HasVarianceCol):
+                             JavaMLWritable, JavaMLReadable):
this was an erroneous change - RandomForest does not have a variance column, unlike DecisionTree, so removing it
  featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
- probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+ minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
+ maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = "probability") {
reordering parameters to match common/expert param types
Test build #68049 has finished for PR 15746 at commit
@mengxr @yanboliang Could you review this?

I'll try to take a look by end of this week.
Sure, I can make a pass tomorrow.

Test build #68072 has finished for PR 15746 at commit
R/pkg/R/mllib.R
#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and
#' \code{write.ml}/\code{read.ml} to save/load fitted models.
#' For more details, see
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{GBT}
Directly linking to http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier and http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression would be clearer?
I thought it was verbose to have 2 links, but I guess they are just links. Added.
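For reference, a plausible shape of the updated roxygen text with both anchors (the exact merged wording may differ):

```r
#' For more details, see
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{GBT Regression} and
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{GBT Classifier}
```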
#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a
#'   split causes the left or right child to have fewer than
#'   minInstancesPerNode, the split will be discarded as invalid. Should be
#'   >= 1.
(default = 1)
I was debating this. The other default text comes from Scala, and I thought it would be nice to have a single text, but generally R doc text does not list the default value since it is clearly stated in the function signature right above on the Rd page.
So I'm removing all other "default = something" text unless it brings additional value (like explaining why).
This is the same in Python.
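To make the convention concrete, a minimal hypothetical roxygen sketch (`fitExample` is not part of this PR): the default already appears in the generated \usage section, so the @param text omits it.

```r
#' @param maxDepth Maximum depth of the tree (>= 0).
#'   Note: no "(default = 5)" here; the Rd page's \usage section
#'   already shows "maxDepth = 5" from the signature below.
fitExample <- function(data, maxDepth = 5) {
  invisible(maxDepth)
}
```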
#' split causes the left or right child to have fewer than
#' minInstancesPerNode, the split will be discarded as invalid. Should be
#' >= 1.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
(default = 0.0)
#' minInstancesPerNode, the split will be discarded as invalid. Should be
#' >= 1.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
(default = 10)
#' >= 1.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
(default = 256)
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
If TRUE, the algorithm will cache node IDs for each instance. (default = FALSE)
Caching can speed up training of deeper trees. Users can set how often the cache should be checkpointed, or disable it, by setting checkpointInterval.
updated.
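A hedged usage sketch of these two parameters through the spark.gbt API added in this PR (the longley dataset choice is illustrative):

```r
library(SparkR)

df <- suppressWarnings(createDataFrame(longley))
# Cache node IDs per instance to speed up training of deeper trees; the cache
# is checkpointed every 10 iterations (this takes effect only when a
# checkpoint directory has been set).
model <- spark.gbt(df, Employed ~ ., type = "regression",
                   maxDepth = 8, cacheNodeIds = TRUE, checkpointInterval = 10)
```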
if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)

val pipeline = new Pipeline()
  .setStages(Array(rFormulaModel, rfc))
I think spark.gbt also needs to support binary classification on a dataset with string labels such as Yes and No. This implementation will output a double value when making predictions, which may confuse users; we should convert the double value back to the original string label. You can refer to NaiveBayesWrapper for how to construct the pipeline. BTW, add an R test for a dataset with string labels.
the existing test uses a string label; I'll add a test for a numeric label
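A minimal sketch of the string-label case under discussion (hypothetical session; GBT classification is binary, hence the two-class subset):

```r
library(SparkR)

iris2 <- iris[iris$Species != "virginica", ]    # two classes: setosa, versicolor
df <- suppressWarnings(createDataFrame(iris2))  # dots in column names become underscores
model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, type = "classification")

# Predictions should come back as the original string labels
# ("setosa"/"versicolor"), not as doubles like 0.0/1.0.
head(select(predict(model, df), "prediction"))
```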
val formula: String,
val features: Array[String]) extends MLWritable {

private val DTModel: GBTClassificationModel =
DTModel -> gbcModel would be better? The variable name should not start with an uppercase letter.
fixed, thx, copy/paste mistake
val formula: String,
val features: Array[String]) extends MLWritable {

private val DTModel: GBTRegressionModel =
Ditto.
  # Prints the summary of Random Forest Regression Model
- print.summary.randomForest <- function(x) {
+ print.summary.treeEnsemble <- function(x) {
I think we should not call toDebugString and output the detailed structure of the trees. This information is used for debugging and is not easy for R users to understand.
Possibly. What would you suggest we show?
In R, generally the evaluated error should be shown in summary, and we don't really have that handy. Also, I seem to recall an ongoing issue on the lack of consistency (or lack of information) displayed to the R user, and it has been suggested we should have helper functions on the model so we could be consistent across the board in all languages (as opposed to on the R side only, like print.summary.GeneralizedLinearRegressionModel)?
I feel like there is a lot of work we could be doing here.
For example, summary on an rpart model shows both the error and node-by-node information. I think it is still useful this way:
Call:
rpart(formula = Kyphosis ~ Age + Number + Start, data = kyphosis,
method = "class")
n= 81
CP nsplit rel error xerror xstd
1 0.17647059 0 1.0000000 1.0000000 0.2155872
2 0.01960784 1 0.8235294 0.9411765 0.2107780
3 0.01000000 4 0.7647059 1.0588235 0.2200975
Variable importance
Start Age Number
64 24 12
Node number 1: 81 observations, complexity param=0.1764706
predicted class=absent expected loss=0.2098765 P(node) =1
class counts: 64 17
probabilities: 0.790 0.210
left son=2 (62 obs) right son=3 (19 obs)
Primary splits:
Start < 8.5 to the right, improve=6.762330, (0 missing)
Number < 5.5 to the left, improve=2.866795, (0 missing)
Age < 39.5 to the left, improve=2.250212, (0 missing)
Surrogate splits:
Number < 6.5 to the left, agree=0.802, adj=0.158, (0 split)
Node number 2: 62 observations, complexity param=0.01960784
predicted class=absent expected loss=0.09677419 P(node) =0.7654321
class counts: 56 6
probabilities: 0.903 0.097
left son=4 (29 obs) right son=5 (33 obs)
Primary splits:
Start < 14.5 to the right, improve=1.0205280, (0 missing)
Age < 55 to the left, improve=0.6848635, (0 missing)
Number < 4.5 to the left, improve=0.2975332, (0 missing)
Surrogate splits:
Number < 3.5 to the left, agree=0.645, adj=0.241, (0 split)
Age < 16 to the left, agree=0.597, adj=0.138, (0 split)
Node number 3: 19 observations
predicted class=present expected loss=0.4210526 P(node) =0.2345679
class counts: 8 11
probabilities: 0.421 0.579
Node number 4: 29 observations
predicted class=absent expected loss=0 P(node) =0.3580247
class counts: 29 0
probabilities: 1.000 0.000
Node number 5: 33 observations, complexity param=0.01960784
predicted class=absent expected loss=0.1818182 P(node) =0.4074074
class counts: 27 6
probabilities: 0.818 0.182
left son=10 (12 obs) right son=11 (21 obs)
Primary splits:
Age < 55 to the left, improve=1.2467530, (0 missing)
Start < 12.5 to the right, improve=0.2887701, (0 missing)
Number < 3.5 to the right, improve=0.1753247, (0 missing)
Surrogate splits:
Start < 9.5 to the left, agree=0.758, adj=0.333, (0 split)
Number < 5.5 to the right, agree=0.697, adj=0.167, (0 split)
Node number 10: 12 observations
predicted class=absent expected loss=0 P(node) =0.1481481
class counts: 12 0
probabilities: 1.000 0.000
Node number 11: 21 observations, complexity param=0.01960784
predicted class=absent expected loss=0.2857143 P(node) =0.2592593
class counts: 15 6
probabilities: 0.714 0.286
left son=22 (14 obs) right son=23 (7 obs)
Primary splits:
Age < 111 to the right, improve=1.71428600, (0 missing)
Start < 12.5 to the right, improve=0.79365080, (0 missing)
Number < 3.5 to the right, improve=0.07142857, (0 missing)
Node number 22: 14 observations
predicted class=absent expected loss=0.1428571 P(node) =0.1728395
class counts: 12 2
probabilities: 0.857 0.143
Node number 23: 7 observations
predicted class=present expected loss=0.4285714 P(node) =0.08641975
class counts: 3 4
probabilities: 0.429 0.571
That's OK. I thought the output may be very large and flood the screen. The output of toDebugString is also not very legible compared with rpart. I like the idea of making the summary string consistent between languages. Let's get this in first and improve toDebugString on the Scala side in a separate task, which can also benefit SparkR.
I'll open a JIRA on that
opened SPARK-18348
…ic label (force index label & predicted label to string), tests
Test build #68166 has finished for PR 15746 at commit
any more thoughts on this?
iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
df <- suppressWarnings(createDataFrame(iris2))
m <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
s <- summary(m)
It looks like we never use this line?
added a test, but this is mostly to make sure the call is not failing
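A hedged sketch of the kind of assertion this becomes (expectations illustrative; `m` is the model from the snippet above):

```r
expect_error(s <- summary(m), NA)  # the summary call itself should not fail
expect_equal(s$numTrees, 20)       # 20 matches the GBT default number of iterations
```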
               68.655, 69.564, 69.331, 70.551),
             tolerance = 1e-4)
stats <- summary(model)
expect_equal(stats$numTrees, 20)
Why only check numTrees? I think we should also check numFeatures, featureImportances and treeWeights at least. Any thoughts?
added. featureImportances is a bit tricky - in the JVM it's a Vector and doesn't translate to something accessible in R (see SPARK-18226),
so for now featureImportances is converted to a string; let's skip testing that for now.
I see. Since there is no object representing an ML Vector in SparkR currently, I'd like to convert the type of featureImportances from Vector to Array in GBTClassifierWrapper.scala:
lazy val featureImportances: Array[Double] = gbtcModel.featureImportances.toArray
Then it can be translated to an R list. Users may sort or select the feature importances, so returning an R list should make more sense. Any thoughts?
I think I tried that, and it was really a SparseVector, so converting to an Array made it fairly unreadable and unusable.
I think SparseVector should really map to a Map or a Properties.
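To make the trade-off concrete, a hypothetical sketch (values invented) contrasting the two shapes discussed - a dense array from the SparseVector versus a map-like structure:

```r
# Dense array from SparseVector.toArray: mostly zeros, feature positions opaque.
importances_array <- c(0, 0, 0.61, 0.37, 0, 0)

# Map-like alternative: a named list keyed by feature, easy to sort or select.
importances <- list(Petal_Length = 0.61, Petal_Width = 0.37)
importances[order(unlist(importances), decreasing = TRUE)]
```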
function(object, path, overwrite = FALSE) {
  write_internal(object, path, overwrite)
})
Perhaps add a line of annotation: "Get the summary of a GBTRegressionModel model". I know it will not appear in the R doc; it is there to help developers understand the code.
R/pkg/R/mllib.R
})

#' @return \code{summary} returns the model's features as lists, depth and number of nodes
#'         or number of classes.
Should we clarify the return values further, such as feature importances, tree weights, number of trees, etc.?
updated. I took a shot at updating other models, but we have a lot of issues with details and consistency across all the other ML models - I'll open a JIRA to track.
opened SPARK-18349
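One plausible shape of the clarified \code{@return} text, using only components mentioned in this thread (the merged wording may differ):

```r
#' @return \code{summary} returns a list of components including \code{formula},
#'         \code{numFeatures}, \code{features}, \code{featureImportances},
#'         \code{numTrees}, and \code{treeWeights}.
```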
@felixcheung I made another pass and left some minor comments, otherwise, looks good to me. Thanks.

Test build #68295 has finished for PR 15746 at commit

Jenkins, retest this please

Test build #68302 has finished for PR 15746 at commit

The best sparse vector support in R comes from the

That's a great suggestion. I've added it to SPARK-18131
merged to master and branch-2.1 |
Author: Felix Cheung <felixcheung_m@hotmail.com>
Closes #15746 from felixcheung/rgbt.
(cherry picked from commit 55964c1)
Signed-off-by: Felix Cheung <felixcheung@apache.org>
What changes were proposed in this pull request?
Gradient Boosted Tree in R.
With a few minor improvements to RandomForest in R.
Since this is relatively isolated I'd like to target this for branch-2.1
How was this patch tested?
manual tests, unit tests