Skip to content

Commit 46c41b1

Browse files
author
Ram Sriharsha
committed
cleanup
1 parent 2f76295 commit 46c41b1

File tree

1 file changed

+24
-23
lines changed

1 file changed

+24
-23
lines changed

docs/ml-ensembles.md

Lines changed: 24 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -56,8 +56,11 @@ println(metrics.confusionMatrix)
5656
// the Iris DataSet has three classes
5757
val numClasses = 3
5858

59-
val fprs = (0 until numClasses).map(label => label + "\t" + metrics.falsePositiveRate(label.toDouble)).mkString("\n")
60-
println("label\tfpr\n" + fprs)
59+
println("label\tfpr\n")
60+
(0 until numClasses).foreach { index =>
61+
val label = index.toDouble
62+
println(label + "\t" + metrics.falsePositiveRate(label))
63+
}
6164
{% endhighlight %}
6265
</div>
6366
<div data-lang="java" markdown="1">
@@ -67,38 +70,37 @@ import org.apache.spark.SparkConf;
6770
import org.apache.spark.api.java.JavaSparkContext;
6871
import org.apache.spark.ml.classification.LogisticRegression;
6972
import org.apache.spark.ml.classification.OneVsRest;
73+
import org.apache.spark.ml.classification.OneVsRestModel;
7074
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
7175
import org.apache.spark.mllib.linalg.Matrix;
7276
import org.apache.spark.mllib.regression.LabeledPoint;
7377
import org.apache.spark.mllib.util.MLUtils;
7478
import org.apache.spark.rdd.RDD;
7579
import org.apache.spark.sql.DataFrame;
76-
import org.apache.spark.sql.Row;
7780
import org.apache.spark.sql.SQLContext;
7881

7982
SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
8083
JavaSparkContext jsc = new JavaSparkContext(conf);
8184
SQLContext jsql = new SQLContext(jsc);
82-
85+
8386
RDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(),
8487
"data/mllib/sample_multiclass_classification_data.txt");
8588

86-
RDD<LabeledPoint>[] split = data.randomSplit(new double[]{0.7, 0.3}, 12345);
87-
RDD<LabeledPoint> train = split[0];
88-
RDD<LabeledPoint> test = split[1];
89+
DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class);
90+
DataFrame[] splits = dataFrame.randomSplit(new double[]{0.7, 0.3}, 12345);
91+
DataFrame train = splits[0];
92+
DataFrame test = splits[1];
8993

9094
// instantiate the One Vs Rest Classifier
9195
OneVsRest ovr = new OneVsRest().setClassifier(new LogisticRegression());
92-
96+
9397
// train the multiclass model
94-
DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
95-
OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());
98+
OneVsRestModel ovrModel = ovr.fit(train.cache());
9699

97100
// score the model on test data
98-
DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
99101
DataFrame predictions = ovrModel
100-
.transform(testDataFrame.cache())
101-
.select("prediction", "label");
102+
.transform(test)
103+
.select("prediction", "label");
102104

103105
// obtain metrics
104106
MulticlassMetrics metrics = new MulticlassMetrics(predictions);
@@ -109,20 +111,19 @@ System.out.println("Confusion Matrix");
109111
System.out.println(confusionMatrix);
110112

111113
// compute the false positive rate per label
112-
StringBuilder results = new StringBuilder();
113-
results.append("label\tfpr\n");
114+
System.out.println();
115+
System.out.println("label\tfpr\n");
114116

115117
// the Iris DataSet has three classes
116118
int numClasses = 3;
117-
118-
for (int label = 0; label < numClasses; label++) {
119-
results.append(label);
120-
results.append("\t");
121-
results.append(metrics.falsePositiveRate((double) label));
122-
results.append("\n");
119+
for (int index = 0; index < numClasses; index++) {
120+
double label = (double) index;
121+
System.out.print(label);
122+
System.out.print("\t");
123+
System.out.print(metrics.falsePositiveRate(label));
124+
System.out.println();
123125
}
124-
System.out.println();
125-
System.out.println(results);
126+
126127
{% endhighlight %}
127128
</div>
128129
</div>

0 commit comments

Comments
 (0)