Commit fc8bbe3

GBT SparkR

1 parent 7c37869 commit fc8bbe3

8 files changed: +590 −25 lines

R/pkg/NAMESPACE

Lines changed: 7 additions & 2 deletions
@@ -45,7 +45,8 @@ exportMethods("glm",
               "spark.als",
               "spark.kstest",
               "spark.logit",
-              "spark.randomForest")
+              "spark.randomForest",
+              "spark.gbt")
 
 # Job group lifecycle management methods
 export("setJobGroup",

@@ -353,7 +354,9 @@ export("as.DataFrame",
        "read.ml",
        "print.summary.KSTest",
        "print.summary.RandomForestRegressionModel",
-       "print.summary.RandomForestClassificationModel")
+       "print.summary.RandomForestClassificationModel",
+       "print.summary.GBTRegressionModel",
+       "print.summary.GBTClassificationModel")
 
 export("structField",
        "structField.jobj",

@@ -380,6 +383,8 @@ S3method(print, summary.GeneralizedLinearRegressionModel)
 S3method(print, summary.KSTest)
 S3method(print, summary.RandomForestRegressionModel)
 S3method(print, summary.RandomForestClassificationModel)
+S3method(print, summary.GBTRegressionModel)
+S3method(print, summary.GBTClassificationModel)
 S3method(structField, character)
 S3method(structField, jobj)
 S3method(structType, jobj)

R/pkg/R/generics.R

Lines changed: 4 additions & 0 deletions
@@ -1343,6 +1343,10 @@ setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
 setGeneric("spark.gaussianMixture",
            function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
 
+#' @rdname spark.gbt
+#' @export
+setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") })
+
 #' @rdname spark.glm
 #' @export
 setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })

R/pkg/R/mllib.R

Lines changed: 231 additions & 18 deletions
Large diffs are not rendered by default.
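
The mllib.R diff (not rendered here) carries the user-facing substance of this commit: the spark.gbt fitting function and its predict, summary, write.ml, and print.summary methods for the new GBTRegressionModel and GBTClassificationModel classes. As orientation, a minimal usage sketch inferred from the tests later in this commit — argument names and behavior are read from those tests, not from the unrendered source:

# Sketch of the spark.gbt API, inferred from test_mllib.R below.
sparkR.session()

# Regression; the third argument selects the task type.
df <- suppressWarnings(createDataFrame(longley))
model <- spark.gbt(df, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123)
summary(model)                       # numTrees, treeWeights, featureImportances, ...
head(collect(predict(model, df)))    # appends a "prediction" column

# Binary classification (GBTClassifier supports only two classes).
bin <- suppressWarnings(createDataFrame(iris[iris$Species != "virginica", ]))
clf <- spark.gbt(bin, Species ~ Petal_Length + Petal_Width, "classification")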

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 51 additions & 0 deletions
@@ -939,4 +939,55 @@ test_that("spark.randomForest Classification", {
   unlink(modelPath)
 })
 
+test_that("spark.gbt", {
+  # regression
+  data <- suppressWarnings(createDataFrame(longley))
+  model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123)
+  predictions <- collect(predict(model, data))
+  expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
+                                         63.221, 63.639, 64.989, 63.761,
+                                         66.019, 67.857, 68.169, 66.513,
+                                         68.655, 69.564, 69.331, 70.551),
+               tolerance = 1e-4)
+  stats <- summary(model)
+  expect_equal(stats$numTrees, 20)
+
+  modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$formula, stats2$formula)
+  expect_equal(stats$numFeatures, stats2$numFeatures)
+  expect_equal(stats$features, stats2$features)
+  expect_equal(stats$featureImportances, stats2$featureImportances)
+  expect_equal(stats$numTrees, stats2$numTrees)
+  expect_equal(stats$treeWeights, stats2$treeWeights)
+
+  unlink(modelPath)
+
+  # classification
+  # label must be binary - GBTClassifier currently only supports binary classification.
+  data <- suppressWarnings(createDataFrame(iris[iris$Species != "virginica", ]))
+  model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification")
+  stats <- summary(model)
+  expect_equal(stats$numFeatures, 2)
+  expect_equal(stats$numTrees, 20)
+  expect_error(capture.output(stats), NA)
+  expect_true(length(capture.output(stats)) > 6)
+
+  modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$depth, stats2$depth)
+  expect_equal(stats$numNodes, stats2$numNodes)
+  expect_equal(stats$numClasses, stats2$numClasses)
+
+  unlink(modelPath)
+})
+
 sparkR.session.stop()

mllib/src/main/scala/org/apache/spark/ml/r/GBTClassifierWrapper.scala

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class GBTClassifierWrapper private (
+    val pipeline: PipelineModel,
+    val formula: String,
+    val features: Array[String]) extends MLWritable {
+
+  private val DTModel: GBTClassificationModel =
+    pipeline.stages(1).asInstanceOf[GBTClassificationModel]
+
+  lazy val numFeatures: Int = DTModel.numFeatures
+  lazy val featureImportances: Vector = DTModel.featureImportances
+  lazy val numTrees: Int = DTModel.getNumTrees
+  lazy val treeWeights: Array[Double] = DTModel.treeWeights
+
+  def summary: String = DTModel.toDebugString
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset).drop(DTModel.getFeaturesCol)
+  }
+
+  override def write: MLWriter = new
+      GBTClassifierWrapper.GBTClassifierWrapperWriter(this)
+}
+
+private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] {
+  def fit(  // scalastyle:ignore
+      data: DataFrame,
+      formula: String,
+      maxDepth: Int,
+      maxBins: Int,
+      maxIter: Int,
+      stepSize: Double,
+      minInstancesPerNode: Int,
+      minInfoGain: Double,
+      checkpointInterval: Int,
+      lossType: String,
+      seed: String,
+      subsamplingRate: Double,
+      maxMemoryInMB: Int,
+      cacheNodeIds: Boolean): GBTClassifierWrapper = {
+
+    val rFormula = new RFormula()
+      .setFormula(formula)
+    RWrapperUtils.checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    // get feature names from output schema
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+
+    // assemble and fit the pipeline
+    val rfc = new GBTClassifier()
+      .setMaxDepth(maxDepth)
+      .setMaxBins(maxBins)
+      .setMaxIter(maxIter)
+      .setStepSize(stepSize)
+      .setMinInstancesPerNode(minInstancesPerNode)
+      .setMinInfoGain(minInfoGain)
+      .setCheckpointInterval(checkpointInterval)
+      .setLossType(lossType)
+      .setSubsamplingRate(subsamplingRate)
+      .setMaxMemoryInMB(maxMemoryInMB)
+      .setCacheNodeIds(cacheNodeIds)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+    if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, rfc))
+      .fit(data)
+
+    new GBTClassifierWrapper(pipeline, formula, features)
+  }
+
+  override def read: MLReader[GBTClassifierWrapper] = new GBTClassifierWrapperReader
+
+  override def load(path: String): GBTClassifierWrapper = super.load(path)
+
+  class GBTClassifierWrapperWriter(instance: GBTClassifierWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("formula" -> instance.formula) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class GBTClassifierWrapperReader extends MLReader[GBTClassifierWrapper] {
+
+    override def load(path: String): GBTClassifierWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val formula = (rMetadata \ "formula").extract[String]
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      new GBTClassifierWrapper(pipeline, formula, features)
+    }
+  }
+}
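
For context: the R front end reaches fit above through SparkR's internal JVM bridge. The actual call lives in the unrendered mllib.R diff, so this is only a hedged sketch — the coercions are assumptions, though the argument order mirrors fit's signature. Note that seed crosses the bridge as a character string, since an R numeric cannot hold a 64-bit long exactly; the Scala side applies setSeed(seed.toLong) only when the string is non-empty.

# Assumed bridge call (sketch only; the real call is in R/pkg/R/mllib.R).
# callJStatic is SparkR's internal entry point for static JVM methods;
# data@sdf is the Java-side DataFrame behind a SparkDataFrame.
jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", "fit",
                    data@sdf, formula, as.integer(maxDepth), as.integer(maxBins),
                    as.integer(maxIter), as.numeric(stepSize),
                    as.integer(minInstancesPerNode), as.numeric(minInfoGain),
                    as.integer(checkpointInterval), as.character(lossType),
                    as.character(seed), as.numeric(subsamplingRate),
                    as.integer(maxMemoryInMB), as.logical(cacheNodeIds))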

mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressorWrapper.scala

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class GBTRegressorWrapper private (
+    val pipeline: PipelineModel,
+    val formula: String,
+    val features: Array[String]) extends MLWritable {
+
+  private val DTModel: GBTRegressionModel =
+    pipeline.stages(1).asInstanceOf[GBTRegressionModel]
+
+  lazy val numFeatures: Int = DTModel.numFeatures
+  lazy val featureImportances: Vector = DTModel.featureImportances
+  lazy val numTrees: Int = DTModel.getNumTrees
+  lazy val treeWeights: Array[Double] = DTModel.treeWeights
+
+  def summary: String = DTModel.toDebugString
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset).drop(DTModel.getFeaturesCol)
+  }
+
+  override def write: MLWriter = new
+      GBTRegressorWrapper.GBTRegressorWrapperWriter(this)
+}
+
+private[r] object GBTRegressorWrapper extends MLReadable[GBTRegressorWrapper] {
+  def fit(  // scalastyle:ignore
+      data: DataFrame,
+      formula: String,
+      maxDepth: Int,
+      maxBins: Int,
+      maxIter: Int,
+      stepSize: Double,
+      minInstancesPerNode: Int,
+      minInfoGain: Double,
+      checkpointInterval: Int,
+      lossType: String,
+      seed: String,
+      subsamplingRate: Double,
+      maxMemoryInMB: Int,
+      cacheNodeIds: Boolean): GBTRegressorWrapper = {
+
+    val rFormula = new RFormula()
+      .setFormula(formula)
+    RWrapperUtils.checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    // get feature names from output schema
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+
+    // assemble and fit the pipeline
+    val rfr = new GBTRegressor()
+      .setMaxDepth(maxDepth)
+      .setMaxBins(maxBins)
+      .setMaxIter(maxIter)
+      .setStepSize(stepSize)
+      .setMinInstancesPerNode(minInstancesPerNode)
+      .setMinInfoGain(minInfoGain)
+      .setCheckpointInterval(checkpointInterval)
+      .setLossType(lossType)
+      .setSubsamplingRate(subsamplingRate)
+      .setMaxMemoryInMB(maxMemoryInMB)
+      .setCacheNodeIds(cacheNodeIds)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+    if (seed != null && seed.length > 0) rfr.setSeed(seed.toLong)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, rfr))
+      .fit(data)
+
+    new GBTRegressorWrapper(pipeline, formula, features)
+  }
+
+  override def read: MLReader[GBTRegressorWrapper] = new GBTRegressorWrapperReader
+
+  override def load(path: String): GBTRegressorWrapper = super.load(path)
+
+  class GBTRegressorWrapperWriter(instance: GBTRegressorWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("formula" -> instance.formula) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class GBTRegressorWrapperReader extends MLReader[GBTRegressorWrapper] {
+
+    override def load(path: String): GBTRegressorWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val formula = (rMetadata \ "formula").extract[String]
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      new GBTRegressorWrapper(pipeline, formula, features)
+    }
+  }
+}

mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,10 @@ private[r] object RWrappers extends MLReader[Object] {
       RandomForestRegressorWrapper.load(path)
     case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
       RandomForestClassifierWrapper.load(path)
+    case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
+      GBTRegressorWrapper.load(path)
+    case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
+      GBTClassifierWrapper.load(path)
     case _ =>
       throw new SparkException(s"SparkR read.ml does not support load $className")
   }
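
RWrappers is the single dispatch point behind read.ml: it reads the class name from the saved rMetadata JSON and routes to the matching wrapper's load, so these two new cases are what let saved GBT models come back. A short round-trip sketch, mirroring the tests above:

# write.ml saves rMetadata (including the wrapper class name) plus the
# fitted pipeline; read.ml hands the path to RWrappers, which matches on
# the class name and calls GBTRegressorWrapper.load / GBTClassifierWrapper.load.
model <- spark.gbt(df, Employed ~ ., "regression")
path <- tempfile(pattern = "spark-gbt", fileext = ".tmp")
write.ml(model, path)
restored <- read.ml(path)
summary(restored)$numTrees   # matches summary(model)$numTrees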
