Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ import scala.language.reflectiveCalls

import scopt.OptionParser

import org.apache.spark.ml.tree.DecisionTreeModel
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
Expand Down Expand Up @@ -64,8 +63,6 @@ object DecisionTreeExample {
maxBins: Int = 32,
minInstancesPerNode: Int = 1,
minInfoGain: Double = 0.0,
numTrees: Int = 1,
featureSubsetStrategy: String = "auto",
fracTest: Double = 0.2,
cacheNodeIds: Boolean = false,
checkpointDir: Option[String] = None,
Expand Down Expand Up @@ -123,8 +120,8 @@ object DecisionTreeExample {
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
if (params.fracTest < 0 || params.fracTest > 1) {
failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
if (params.fracTest < 0 || params.fracTest >= 1) {
failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
} else {
success
}
Expand Down Expand Up @@ -200,9 +197,18 @@ object DecisionTreeExample {
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
}
val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
val dataframes = splits.map(_.toDF()).map(labelsToStrings)
val training = dataframes(0).cache()
val test = dataframes(1).cache()

(dataframes(0), dataframes(1))
val numTraining = training.count()
val numTest = test.count()
val numFeatures = training.select("features").first().getAs[Vector](0).size
println("Loaded data:")
println(s" numTraining = $numTraining, numTest = $numTest")
println(s" numFeatures = $numFeatures")

(training, test)
}

def run(params: Params) {
Expand All @@ -217,13 +223,6 @@ object DecisionTreeExample {
val (training: DataFrame, test: DataFrame) =
loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)

val numTraining = training.count()
val numTest = test.count()
val numFeatures = training.select("features").first().getAs[Vector](0).size
println("Loaded data:")
println(s" numTraining = $numTraining, numTest = $numTest")
println(s" numFeatures = $numFeatures")

// Set up Pipeline
val stages = new mutable.ArrayBuffer[PipelineStage]()
// (1) For classification, re-index classes.
Expand All @@ -241,7 +240,7 @@ object DecisionTreeExample {
.setOutputCol("indexedFeatures")
.setMaxCategories(10)
stages += featuresIndexer
// (3) Learn DecisionTree
// (3) Learn Decision Tree
val dt = algo match {
case "classification" =>
new DecisionTreeClassifier()
Expand Down Expand Up @@ -275,62 +274,86 @@ object DecisionTreeExample {
println(s"Training time: $elapsedTime seconds")

// Get the trained Decision Tree from the fitted PipelineModel
val treeModel: DecisionTreeModel = algo match {
algo match {
case "classification" =>
pipelineModel.getModel[DecisionTreeClassificationModel](
val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
dt.asInstanceOf[DecisionTreeClassifier])
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
}
case "regression" =>
pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor])
case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
}

// Predict on training
val trainingFullPredictions = pipelineModel.transform(training).cache()
val trainingPredictions = trainingFullPredictions.select("prediction")
.map(_.getDouble(0))
val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0))
// Predict on test data
val testFullPredictions = pipelineModel.transform(test).cache()
val testPredictions = testFullPredictions.select("prediction")
.map(_.getDouble(0))
val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0))

// For classification, print number of classes for reference.
if (algo == "classification") {
val numClasses =
MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match {
case Some(n) => n
case None => throw new RuntimeException(
"DecisionTreeExample had unknown failure when indexing labels for classification.")
val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
dt.asInstanceOf[DecisionTreeRegressor])
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
}
println(s"numClasses = $numClasses.")
case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}

// Evaluate model on training, test data
algo match {
case "classification" =>
val trainingAccuracy =
new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision
println(s"Train accuracy = $trainingAccuracy")
val testAccuracy =
new MulticlassMetrics(testPredictions.zip(testLabels)).precision
println(s"Test accuracy = $testAccuracy")
println("Training data results:")
evaluateClassificationModel(pipelineModel, training, labelColName)
println("Test data results:")
evaluateClassificationModel(pipelineModel, test, labelColName)
case "regression" =>
val trainingRMSE =
new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError
println(s"Training root mean squared error (RMSE) = $trainingRMSE")
val testRMSE =
new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError
println(s"Test root mean squared error (RMSE) = $testRMSE")
println("Training data results:")
evaluateRegressionModel(pipelineModel, training, labelColName)
println("Test data results:")
evaluateRegressionModel(pipelineModel, test, labelColName)
case _ =>
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}

sc.stop()
}

/**
* Evaluate the given ClassificationModel on data. Print the results.
* @param model Must fit ClassificationModel abstraction
* @param data DataFrame with "prediction" and labelColName columns
* @param labelColName Name of the labelCol parameter for the model
*
* TODO: Change model type to ClassificationModel once that API is public. SPARK-5995
*/
private[ml] def evaluateClassificationModel(
model: Transformer,
data: DataFrame,
labelColName: String): Unit = {
val fullPredictions = model.transform(data).cache()
val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
// Print number of classes for reference
val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match {
case Some(n) => n
case None => throw new RuntimeException(
"Unknown failure when indexing labels for classification.")
}
val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision
println(s" Accuracy ($numClasses classes): $accuracy")
}

/**
* Evaluate the given RegressionModel on data. Print the results.
* @param model Must fit RegressionModel abstraction
* @param data DataFrame with "prediction" and labelColName columns
* @param labelColName Name of the labelCol parameter for the model
*
* TODO: Change model type to RegressionModel once that API is public. SPARK-5995
*/
private[ml] def evaluateRegressionModel(
model: Transformer,
data: DataFrame,
labelColName: String): Unit = {
val fullPredictions = model.transform(data).cache()
val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError
println(s" Root mean squared error (RMSE): $RMSE")
}
}
Loading