@@ -16,33 +16,30 @@
  */

 // scalastyle:off println
-package org.apache.spark.examples.mllib
+package org.apache.spark.examples.ml

 import java.io.File

 import com.google.common.io.Files
 import scopt.OptionParser

 import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
-import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}

 /**
- * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with
+ * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with
  * {{{
- * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * ./bin/run-example ml.DataFrameExample [options]
  * }}}
  * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
  */
-object DatasetExample {
+object DataFrameExample {

-  case class Params(
-      input: String = "data/mllib/sample_libsvm_data.txt",
-      dataFormat: String = "libsvm") extends AbstractParams[Params]
+  case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
+    extends AbstractParams[Params]

   def main(args: Array[String]) {
     val defaultParams = Params()
@@ -52,9 +49,6 @@ object DatasetExample {
       opt[String]("input")
         .text(s"input path to dataset")
         .action((x, c) => c.copy(input = x))
-      opt[String]("dataFormat")
-        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
-        .action((x, c) => c.copy(input = x))
       checkConfig { params =>
         success
       }
@@ -69,55 +63,42 @@ object DatasetExample {

   def run(params: Params) {

-    val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+    val conf = new SparkConf().setAppName(s"DataFrameExample with $params")
     val sc = new SparkContext(conf)
     val sqlContext = new SQLContext(sc)
-    import sqlContext.implicits._ // for implicit conversions

-    // Load input data
-    val origData: RDD[LabeledPoint] = params.dataFormat match {
-      case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
-      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
-    }
-    println(s"Loaded ${origData.count()} instances from file: ${params.input}")
-
-    // Convert input data to DataFrame explicitly.
-    val df: DataFrame = origData.toDF()
-    println(s"Inferred schema:\n${df.schema.prettyJson}")
-    println(s"Converted to DataFrame with ${df.count()} records")
-
-    // Select columns
-    val labelsDf: DataFrame = df.select("label")
-    val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v }
-    val numLabels = labels.count()
-    val meanLabel = labels.fold(0.0)(_ + _) / numLabels
-    println(s"Selected label column with average value $meanLabel")
-
-    val featuresDf: DataFrame = df.select("features")
-    val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v }
+    println(s"Loading LIBSVM file with UDT from ${params.input}.")
+    val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache()
+    println("Schema from LIBSVM:")
+    df.printSchema()
+    println(s"Loaded training data as a DataFrame with ${df.count()} records.")
+
+    // Show statistical summary of labels.
+    val labelSummary = df.describe("label")
+    labelSummary.show()
+
+    // Convert features column to an RDD of vectors.
+    val features = df.select("features").map { case Row(v: Vector) => v }
     val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
       (summary, feat) => summary.add(feat),
       (sum1, sum2) => sum1.merge(sum2))
     println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

+    // Save the records in a parquet file.
     val tmpDir = Files.createTempDir()
     tmpDir.deleteOnExit()
     val outputDir = new File(tmpDir, "dataset").toString
     println(s"Saving to $outputDir as Parquet file.")
     df.write.parquet(outputDir)

+    // Load the records back.
     println(s"Loading Parquet file with UDT from $outputDir.")
-    val newDataset = sqlContext.read.parquet(outputDir)
-
-    println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
-    val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v }
-    val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
-      (summary, feat) => summary.add(feat),
-      (sum1, sum2) => sum1.merge(sum2))
-    println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+    val newDF = sqlContext.read.parquet(outputDir)
+    println(s"Schema from Parquet:")
+    newDF.printSchema()

     sc.stop()
   }

 }
 // scalastyle:on println
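
If you use this example as a template (per the note in its scaladoc), the refactored flow condenses to roughly the sketch below. This is a non-authoritative summary of the code added in this diff, written against the Spark 1.x APIs it uses (SQLContext and the "libsvm" DataFrame source); the object name, app name, and hard-coded input path are illustrative assumptions, and a real app would take the path from Params and launch via spark-submit.

// Condensed sketch of DataFrameExample's run(), using the Spark 1.x APIs
// shown in the diff above. The object name, app name, and input path are
// placeholders, not part of the change itself.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.sql.{Row, SQLContext}

object DataFrameExampleSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("DataFrameExampleSketch"))
    val sqlContext = new SQLContext(sc)

    // The "libsvm" data source replaces MLUtils.loadLibSVMFile + toDF().
    val df = sqlContext.read.format("libsvm")
      .load("data/mllib/sample_libsvm_data.txt").cache()
    df.printSchema()

    // describe() replaces the hand-rolled label mean of the old example.
    df.describe("label").show()

    // Feature statistics still go through the RDD-based summarizer.
    val summary = df.select("features").map { case Row(v: Vector) => v }
      .aggregate(new MultivariateOnlineSummarizer())(
        (s, v) => s.add(v),
        (s1, s2) => s1.merge(s2))
    println(s"Feature means: ${summary.mean}")

    sc.stop()
  }
}

Note the design split the diff preserves: labels are summarized with the DataFrame-native describe(), while feature vectors (a UDT column) still drop down to an RDD and MultivariateOnlineSummarizer, since DataFrame aggregations in this API generation do not operate on Vector columns.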