@@ -56,8 +56,11 @@ println(metrics.confusionMatrix)
5656// the Iris DataSet has three classes
5757val numClasses = 3
5858
59- val fprs = (0 until numClasses).map(label => label + "\t" + metrics.falsePositiveRate(label.toDouble)).mkString("\n")
60- println("label\tfpr\n" + fprs)
59+ println("label\tfpr\n")
60+ (0 until numClasses).foreach { index =>
61+ val label = index.toDouble
62+ println(label + "\t" + metrics.falsePositiveRate(label))
63+ }
6164{% endhighlight %}
6265</div>
6366<div data-lang="java" markdown="1">
@@ -67,38 +70,37 @@ import org.apache.spark.SparkConf;
6770import org.apache.spark.api.java.JavaSparkContext;
6871import org.apache.spark.ml.classification.LogisticRegression;
6972import org.apache.spark.ml.classification.OneVsRest;
73+ import org.apache.spark.ml.classification.OneVsRestModel;
7074import org.apache.spark.mllib.evaluation.MulticlassMetrics;
7175import org.apache.spark.mllib.linalg.Matrix;
7276import org.apache.spark.mllib.regression.LabeledPoint;
7377import org.apache.spark.mllib.util.MLUtils;
7478import org.apache.spark.rdd.RDD;
7579import org.apache.spark.sql.DataFrame;
76- import org.apache.spark.sql.Row;
7780import org.apache.spark.sql.SQLContext;
7881
7982SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
8083JavaSparkContext jsc = new JavaSparkContext(conf);
8184SQLContext jsql = new SQLContext(jsc);
82-
85+
8386RDD<LabeledPoint > data = MLUtils.loadLibSVMFile(jsc.sc(),
8487 "data/mllib/sample_multiclass_classification_data.txt");
8588
86- RDD<LabeledPoint >[ ] split = data.randomSplit(new double[ ] {0.7, 0.3}, 12345);
87- RDD<LabeledPoint > train = split[ 0] ;
88- RDD<LabeledPoint > test = split[ 1] ;
89+ DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class);
90+ DataFrame[ ] splits = dataFrame.randomSplit(new double[ ] {0.7, 0.3}, 12345);
91+ DataFrame train = splits[ 0] ;
92+ DataFrame test = splits[ 1] ;
8993
9094// instantiate the One Vs Rest Classifier
9195OneVsRest ovr = new OneVsRest().setClassifier(new LogisticRegression());
92-
96+
9397// train the multiclass model
94- DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
95- OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());
98+ OneVsRestModel ovrModel = ovr.fit(train.cache());
9699
97100// score the model on test data
98- DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
99101DataFrame predictions = ovrModel
100- .transform(testDataFrame.cache() )
101- .select("prediction", "label");
102+ .transform(test )
103+ .select("prediction", "label");
102104
103105// obtain metrics
104106MulticlassMetrics metrics = new MulticlassMetrics(predictions);
@@ -109,20 +111,19 @@ System.out.println("Confusion Matrix");
109111System.out.println(confusionMatrix);
110112
111113// compute the false positive rate per label
112- StringBuilder results = new StringBuilder ();
113- results.append ("label\tfpr\n");
114+ System.out.println ();
115+ System.out.println ("label\tfpr\n");
114116
115117// the Iris DataSet has three classes
116118int numClasses = 3;
117-
118- for (int label = 0; label < numClasses; label++) {
119- results.append (label);
120- results.append ("\t");
121- results.append (metrics.falsePositiveRate((double) label));
122- results.append("\n" );
119+ for (int index = 0; index < numClasses; index++) {
120+ double label = (double) index;
121+ System.out.print (label);
122+ System.out.print ("\t");
123+ System.out.print (metrics.falsePositiveRate(label));
124+ System.out.println( );
123125}
124- System.out.println();
125- System.out.println(results);
126+
126127{% endhighlight %}
127128</div>
128129</div>
0 commit comments