  */

 // scalastyle:off println
-package org.apache.spark.examples.mllib
+package org.apache.spark.examples.ml

 import java.io.File

 import com.google.common.io.Files
 import scopt.OptionParser

 import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
-import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}

 /**
- * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with
+ * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with
  * {{{
- * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * ./bin/run-example ml.DataFrameExample [options]
  * }}}
  * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
  */
-object DatasetExample {
+object DataFrameExample {

-  case class Params(
-      input: String = "data/mllib/sample_libsvm_data.txt",
-      dataFormat: String = "libsvm") extends AbstractParams[Params]
+  case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
+    extends AbstractParams[Params]

   def main(args: Array[String]) {
     val defaultParams = Params()
@@ -52,9 +49,6 @@ object DatasetExample {
       opt[String]("input")
         .text(s"input path to dataset")
         .action((x, c) => c.copy(input = x))
-      opt[String]("dataFormat")
-        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
-        .action((x, c) => c.copy(input = x))
       checkConfig { params =>
         success
       }
@@ -69,55 +63,42 @@ object DatasetExample {

   def run(params: Params) {

-    val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+    val conf = new SparkConf().setAppName(s"DataFrameExample with $params")
     val sc = new SparkContext(conf)
     val sqlContext = new SQLContext(sc)
-    import sqlContext.implicits._  // for implicit conversions

     // Load input data
-    val origData: RDD[LabeledPoint] = params.dataFormat match {
-      case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
-      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
-    }
-    println(s"Loaded ${origData.count()} instances from file: ${params.input}")
-
-    // Convert input data to DataFrame explicitly.
-    val df: DataFrame = origData.toDF()
-    println(s"Inferred schema:\n${df.schema.prettyJson}")
-    println(s"Converted to DataFrame with ${df.count()} records")
-
-    // Select columns
-    val labelsDf: DataFrame = df.select("label")
-    val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v }
-    val numLabels = labels.count()
-    val meanLabel = labels.fold(0.0)(_ + _) / numLabels
-    println(s"Selected label column with average value $meanLabel")
-
-    val featuresDf: DataFrame = df.select("features")
-    val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v }
+    println(s"Loading LIBSVM file with UDT from ${params.input}.")
+    val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache()
+    println("Schema from LIBSVM:")
+    df.printSchema()
+    println(s"Loaded training data as a DataFrame with ${df.count()} records.")
+
+    // Show statistical summary of labels.
+    val labelSummary = df.describe("label")
+    labelSummary.show()
+
+    // Convert features column to an RDD of vectors.
+    val features = df.select("features").map { case Row(v: Vector) => v }
     val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
       (summary, feat) => summary.add(feat),
       (sum1, sum2) => sum1.merge(sum2))
     println(s"Selected features column with average values:\n${featureSummary.mean.toString}")

+    // Save the records in a parquet file.
     val tmpDir = Files.createTempDir()
     tmpDir.deleteOnExit()
     val outputDir = new File(tmpDir, "dataset").toString
     println(s"Saving to $outputDir as Parquet file.")
     df.write.parquet(outputDir)

+    // Load the records back.
     println(s"Loading Parquet file with UDT from $outputDir.")
-    val newDataset = sqlContext.read.parquet(outputDir)
-
-    println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
-    val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v }
-    val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
-      (summary, feat) => summary.add(feat),
-      (sum1, sum2) => sum1.merge(sum2))
-    println(s"Selected features column with average values:\n${newFeaturesSummary.mean.toString}")
+    val newDF = sqlContext.read.parquet(outputDir)
+    println(s"Schema from Parquet:")
+    newDF.printSchema()

     sc.stop()
   }
-
 }
 // scalastyle:on println
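For readers following along, here is a minimal standalone sketch of the new loading path. It is not part of the commit: it only restates the calls the updated example makes (the "libsvm" data source, describe(), and MultivariateOnlineSummarizer) in a self-contained app, assuming the same Spark 1.6-era APIs and the bundled sample dataset. The object name LibSVMLoadSketch is hypothetical.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.sql.{Row, SQLContext}

object LibSVMLoadSketch {
  def main(args: Array[String]): Unit = {
    // Local master for a quick trial; drop setMaster when submitting with spark-submit.
    val sc = new SparkContext(
      new SparkConf().setAppName("LibSVMLoadSketch").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)

    // The "libsvm" data source yields a DataFrame with "label" and "features" columns.
    val df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    df.printSchema()
    df.describe("label").show()

    // Stream the feature vectors through a running summarizer, as the example does.
    val summary = df.select("features").map { case Row(v: Vector) => v }
      .aggregate(new MultivariateOnlineSummarizer())(
        (s, v) => s.add(v),
        (s1, s2) => s1.merge(s2))
    println(s"Feature means: ${summary.mean}")

    sc.stop()
  }
}

As the scaladoc above notes, a real application built from this template would be packaged and launched with spark-submit rather than run with a hard-coded master.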