
Conversation

@felixcheung felixcheung commented Nov 3, 2016

What changes were proposed in this pull request?

Gradient Boosted Tree in R.
With a few minor improvements to RandomForest in R.

Since this is relatively isolated I'd like to target this for branch-2.1

How was this patch tested?

manual tests, unit tests

class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
RandomForestParams, TreeRegressorParams, HasCheckpointInterval,
JavaMLWritable, JavaMLReadable, HasVarianceCol):
JavaMLWritable, JavaMLReadable):
@felixcheung felixcheung (Member Author) Nov 3, 2016

This was an erroneous change: RandomForest does not have a variance column, unlike DecisionTree, so I'm removing it.

featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) {
minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = "probability") {
Member Author

reordering parameter to match common/expert param types

SparkQA commented Nov 3, 2016

Test build #68049 has finished for PR 15746 at commit fc8bbe3.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class GBTClassifierWrapperWriter(instance: GBTClassifierWrapper)
    • class GBTClassifierWrapperReader extends MLReader[GBTClassifierWrapper]
    • class GBTRegressorWrapperWriter(instance: GBTRegressorWrapper)
    • class GBTRegressorWrapperReader extends MLReader[GBTRegressorWrapper]

shivaram commented Nov 3, 2016

@mengxr @yanboliang Could you review this ? I'll try to take a look by end of this week.

@yanboliang (Contributor)

Sure, I can make a pass tomorrow.

SparkQA commented Nov 3, 2016

Test build #68072 has finished for PR 15746 at commit 1a317ac.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

R/pkg/R/mllib.R Outdated
#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and
#' \code{write.ml}/\code{read.ml} to save/load fitted models.
#' For more details, see
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{GBT}
Contributor

Directly linking to http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier and http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression would be clearer?

Member Author

I thought it's verbose to have 2 links, but I guess they are just links. Added.

#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a
#' split causes the left or right child to have fewer than
#' minInstancesPerNode, the split will be discarded as invalid. Should be
#' >= 1.
Contributor

(default = 1)

@felixcheung felixcheung (Member Author) Nov 4, 2016

I was debating this. The other default text comes from Scala; I thought it would be nice to keep the text consistent, but R doc text generally does not list the default value since it is clearly stated in the function signature right above on the Rd page.

So I'm removing all other "default = something" text unless they bring additional values (like explaining why).

This is the same in Python.

#' split causes the left or right child to have fewer than
#' minInstancesPerNode, the split will be discarded as invalid. Should be
#' >= 1.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
Contributor

(default = 0.0)

#' minInstancesPerNode, the split will be discarded as invalid. Should be
#' >= 1.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
Contributor

(default = 10)

#' >= 1.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
Contributor

(default = 256)

#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
Contributor

If TRUE, the algorithm will cache node IDs for each instance. (default = FALSE)
Caching can speed up training of deeper trees. Users can set how often the cache should be checkpointed, or disable it, by setting checkpointInterval.

Member Author

updated.

if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, rfc))
Contributor

I think spark.gbt also needs to support binary classification on datasets with string labels such as Yes and No. This implementation will output a double value when making predictions, which may confuse users; we should convert the double value back to the original string label. You can refer to NaiveBayesWrapper for how to construct the pipeline. BTW, add an R test for a dataset with string labels.

Member Author

The existing test uses a string label; I'll add a test for a numeric label.
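The round-trip the reviewer describes (index string labels to doubles for training, then map raw double predictions back to the readable strings) can be sketched framework-free. The Python below is purely illustrative, not the wrapper's actual code; the frequency-descending ordering mirrors Spark's StringIndexer default:

```python
from collections import Counter

# Index string labels by descending frequency, train on the doubles,
# then map raw double predictions back to the original strings.
labels = ["Yes", "No", "Yes", "Yes"]

ordered = [lab for lab, _ in Counter(labels).most_common()]
to_index = {lab: float(i) for i, lab in enumerate(ordered)}
to_label = {i: lab for lab, i in to_index.items()}

indexed = [to_index[lab] for lab in labels]        # [0.0, 1.0, 0.0, 0.0]
raw_predictions = [0.0, 1.0, 0.0]                  # what the model would emit
readable = [to_label[p] for p in raw_predictions]  # ["Yes", "No", "Yes"]
```

Without the final mapping step, the user only ever sees the doubles, which is the confusion the review comment is pointing at.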

val formula: String,
val features: Array[String]) extends MLWritable {

private val DTModel: GBTClassificationModel =
Contributor

DTModel -> gbcModel should be better? The variable name should not start with an uppercase letter.

Member Author

fixed, thx, copy/paste mistake

val formula: String,
val features: Array[String]) extends MLWritable {

private val DTModel: GBTRegressionModel =
Contributor

Ditto.


# Prints the summary of Random Forest Regression Model
print.summary.randomForest <- function(x) {
print.summary.treeEnsemble <- function(x) {
Contributor

I think we should not call toDebugString and output the detailed structure of the trees. This information is used for debugging and is not easy for R users to understand.

@felixcheung felixcheung (Member Author) Nov 4, 2016

Possibly. What would you suggest we show?
In R, the evaluated error is generally shown in the summary, but we don't really have that handy. Also, I seem to recall an ongoing issue about the lack of consistency (or lack of information) displayed to R users, and it has been suggested that we should have helper functions on the model so we could be consistent across the board in all languages (as opposed to only on the R side, like print.summary.GeneralizedLinearRegressionModel)?

I feel like there is a lot of work we could be doing here.

Member Author

For example, summary on an rpart model shows both error and node-by-node information. I think it is still useful this way:

Call:
rpart(formula = Kyphosis ~ Age + Number + Start, data = kyphosis,
    method = "class")
  n= 81

          CP nsplit rel error    xerror      xstd
1 0.17647059      0 1.0000000 1.0000000 0.2155872
2 0.01960784      1 0.8235294 0.9411765 0.2107780
3 0.01000000      4 0.7647059 1.0588235 0.2200975

Variable importance
 Start    Age Number
    64     24     12

Node number 1: 81 observations,    complexity param=0.1764706
  predicted class=absent   expected loss=0.2098765  P(node) =1
    class counts:    64    17
   probabilities: 0.790 0.210
  left son=2 (62 obs) right son=3 (19 obs)
  Primary splits:
      Start  < 8.5  to the right, improve=6.762330, (0 missing)
      Number < 5.5  to the left,  improve=2.866795, (0 missing)
      Age    < 39.5 to the left,  improve=2.250212, (0 missing)
  Surrogate splits:
      Number < 6.5  to the left,  agree=0.802, adj=0.158, (0 split)

Node number 2: 62 observations,    complexity param=0.01960784
  predicted class=absent   expected loss=0.09677419  P(node) =0.7654321
    class counts:    56     6
   probabilities: 0.903 0.097
  left son=4 (29 obs) right son=5 (33 obs)
  Primary splits:
      Start  < 14.5 to the right, improve=1.0205280, (0 missing)
      Age    < 55   to the left,  improve=0.6848635, (0 missing)
      Number < 4.5  to the left,  improve=0.2975332, (0 missing)
  Surrogate splits:
      Number < 3.5  to the left,  agree=0.645, adj=0.241, (0 split)
      Age    < 16   to the left,  agree=0.597, adj=0.138, (0 split)

Node number 3: 19 observations
  predicted class=present  expected loss=0.4210526  P(node) =0.2345679
    class counts:     8    11
   probabilities: 0.421 0.579

Node number 4: 29 observations
  predicted class=absent   expected loss=0  P(node) =0.3580247
    class counts:    29     0
   probabilities: 1.000 0.000

Node number 5: 33 observations,    complexity param=0.01960784
  predicted class=absent   expected loss=0.1818182  P(node) =0.4074074
    class counts:    27     6
   probabilities: 0.818 0.182
  left son=10 (12 obs) right son=11 (21 obs)
  Primary splits:
      Age    < 55   to the left,  improve=1.2467530, (0 missing)
      Start  < 12.5 to the right, improve=0.2887701, (0 missing)
      Number < 3.5  to the right, improve=0.1753247, (0 missing)
  Surrogate splits:
      Start  < 9.5  to the left,  agree=0.758, adj=0.333, (0 split)
      Number < 5.5  to the right, agree=0.697, adj=0.167, (0 split)

Node number 10: 12 observations
  predicted class=absent   expected loss=0  P(node) =0.1481481
    class counts:    12     0
   probabilities: 1.000 0.000

Node number 11: 21 observations,    complexity param=0.01960784
  predicted class=absent   expected loss=0.2857143  P(node) =0.2592593
    class counts:    15     6
   probabilities: 0.714 0.286
  left son=22 (14 obs) right son=23 (7 obs)
  Primary splits:
      Age    < 111  to the right, improve=1.71428600, (0 missing)
      Start  < 12.5 to the right, improve=0.79365080, (0 missing)
      Number < 3.5  to the right, improve=0.07142857, (0 missing)

Node number 22: 14 observations
  predicted class=absent   expected loss=0.1428571  P(node) =0.1728395
    class counts:    12     2
   probabilities: 0.857 0.143

Node number 23: 7 observations
  predicted class=present  expected loss=0.4285714  P(node) =0.08641975
    class counts:     3     4
   probabilities: 0.429 0.571

Contributor

That's OK. I thought the output might be very large and flood the screen, and the output of toDebugString is also not very legible compared with rpart. I like the idea of making the summary string consistent between languages. Let's get this in first and improve toDebugString on the Scala side in a separate task, which can also benefit SparkR.

Member Author

I'll open a JIRA on that

Member Author

opened SPARK-18348

…ic label (force index label & predicted label to string), tests
SparkQA commented Nov 5, 2016

Test build #68166 has finished for PR 15746 at commit 94bdf73.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@felixcheung (Member Author)

any more thought on this?

iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
df <- suppressWarnings(createDataFrame(iris2))
m <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
s <- summary(m)
Contributor

It looks like we never use this line?

Member Author

added a test, but this is mostly to make sure the call is not failing

68.655, 69.564, 69.331, 70.551),
tolerance = 1e-4)
stats <- summary(model)
expect_equal(stats$numTrees, 20)
Contributor

Why only check numTrees? I think we should also check numFeatures, featureImportances and treeWeights at least. Any thoughts?

@felixcheung felixcheung (Member Author) Nov 7, 2016

added. featureImportances is a bit tricky - in the JVM it's a Vector and doesn't translate to something accessible in R (see SPARK-18226)

so for now featureImportances is converted to a string, and let's skip testing that for now

@yanboliang yanboliang (Contributor) Nov 8, 2016

I see. Since there is currently no object representing an ML Vector in SparkR, I'd like to convert the type of featureImportances from Vector to Array in GBTClassifierWrapper.scala:

lazy val featureImportances: Array[Double] = gbtcModel.featureImportances.toArray

Then it can be translated to an R list. Users may want to sort or select the feature importances, so returning an R list makes more sense. Any thoughts?

@felixcheung felixcheung (Member Author) Nov 8, 2016

I think I tried that, but it was really a SparseVector, so converting it to an Array made it fairly unreadable and unusable.

I think SparseVector should really map to a Map or a Properties.
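The readability point can be illustrated without Spark. The numbers below are made up, but they show why a sparse importance vector is easier to consume as an index-to-value map than as a mostly-zero array:

```python
# A sparse vector is (size, indices, values); most dense slots are zero.
size = 8
indices = [1, 5]
values = [0.7, 0.3]

# Densifying pads with zeros, so position alone carries the meaning.
dense = [0.0] * size
for i, v in zip(indices, values):
    dense[i] = v
# dense: [0.0, 0.7, 0.0, 0.0, 0.0, 0.3, 0.0, 0.0]

# A map keeps only the informative entries, keyed by feature index.
as_map = dict(zip(indices, values))
# as_map: {1: 0.7, 5: 0.3}
```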

function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})

Contributor

Perhaps add a line of annotation: "Get the summary of a GBTRegressionModel". I know it will not appear in the R doc; it helps developers understand the code.

R/pkg/R/mllib.R Outdated
})

#' @return \code{summary} returns the model's features as lists, depth and number of nodes
#' or number of classes.
Contributor

Should we describe the return values more clearly, such as feature importances, tree weights, number of trees, etc.?

Member Author

updated. I took a shot at updating other models but we have a lot of issues with details and consistency across all other ml models - I'll open a JIRA to track.

Member Author

opened SPARK-18349



@yanboliang (Contributor)

@felixcheung I made another pass and left some minor comments, otherwise, looks good to me. Thanks.

SparkQA commented Nov 7, 2016

Test build #68295 has finished for PR 15746 at commit af400bc.

  • This patch fails SparkR unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@felixcheung (Member Author)

Jenkins, retest this please

SparkQA commented Nov 7, 2016

Test build #68302 has finished for PR 15746 at commit af400bc.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

shivaram commented Nov 8, 2016

The best sparse vector support in R comes from the Matrix package, but it's a big package and I don't think we should add it as a dependency. We could try a wrapper where, if the user already has the package installed, we return the result using Matrix?
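The optional-dependency pattern shivaram suggests (return a richer object only when the package is already installed, otherwise fall back to a plain structure) can be sketched as follows. This is a hypothetical illustration in Python, with scipy standing in for R's Matrix package:

```python
import importlib

def to_sparse(indices, values, size, pkg="scipy"):
    """Return a rich sparse object if `pkg` is installed, else a plain dict."""
    if importlib.util.find_spec(pkg) is not None:
        # Optional dependency is available: build the richer representation.
        sparse = importlib.import_module(pkg + ".sparse")
        return sparse.csr_matrix((values, ([0] * len(indices), indices)),
                                 shape=(1, size))
    # Fallback: a plain structure carrying the same information.
    return {"indices": indices, "values": values, "size": size}
```

The caller gets the best representation the environment supports, and no hard dependency is added to the package itself.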

felixcheung commented Nov 8, 2016

That's a great suggestion. I've added to SPARK-18131

@felixcheung (Member Author)

merged to master and branch-2.1

asfgit pushed a commit that referenced this pull request Nov 9, 2016
## What changes were proposed in this pull request?

Gradient Boosted Tree in R.
With a few minor improvements to RandomForest in R.

Since this is relatively isolated I'd like to target this for branch-2.1

## How was this patch tested?

manual tests, unit tests

Author: Felix Cheung <felixcheung_m@hotmail.com>

Closes #15746 from felixcheung/rgbt.

(cherry picked from commit 55964c1)
Signed-off-by: Felix Cheung <felixcheung@apache.org>
@asfgit asfgit closed this in 55964c1 Nov 9, 2016
uzadude pushed a commit to uzadude/spark that referenced this pull request Jan 27, 2017
## What changes were proposed in this pull request?

Gradient Boosted Tree in R.
With a few minor improvements to RandomForest in R.

Since this is relatively isolated I'd like to target this for branch-2.1

## How was this patch tested?

manual tests, unit tests

Author: Felix Cheung <felixcheung_m@hotmail.com>

Closes apache#15746 from felixcheung/rgbt.