29 commits
c2c2fcd
WIP
marmbrus Feb 25, 2016
4687a66
WIP
marmbrus Feb 25, 2016
0bf0d02
WIP: basic read/write workign
marmbrus Feb 25, 2016
1f35b90
WIP: trying to get appending
marmbrus Feb 26, 2016
4bc04e3
working on partitioning
marmbrus Feb 26, 2016
a27b4a6
WIP: many tests passing
marmbrus Feb 26, 2016
159e4c4
WIP: parquet/hive compiling
marmbrus Feb 28, 2016
7299660
:(
marmbrus Feb 29, 2016
049ac1b
much of hive passing
marmbrus Mar 1, 2016
405f284
Merge remote-tracking branch 'apache/master' into fileDataSource
marmbrus Mar 1, 2016
d28300b
more progress
marmbrus Mar 1, 2016
6b13674
WIP
marmbrus Mar 2, 2016
a975f2d
WIP: all but bucketing
marmbrus Mar 2, 2016
5275c41
Still workign on bucketing...
marmbrus Mar 3, 2016
0d4b08a
restore
marmbrus Mar 3, 2016
428a62f
remove
marmbrus Mar 3, 2016
1a41e15
fix all tests
cloud-fan Mar 3, 2016
2a49e8a
Merge pull request #32 from cloud-fan/fileDataSource
marmbrus Mar 3, 2016
023f133
Merge remote-tracking branch 'apache/master' into fileDataSource
marmbrus Mar 3, 2016
83fbb44
TESTS PASSING?\!?
marmbrus Mar 4, 2016
175e78f
cleanup
marmbrus Mar 4, 2016
216078c
style
marmbrus Mar 4, 2016
ac54278
Merge remote-tracking branch 'apache/master' into fileDataSource
marmbrus Mar 4, 2016
af8baff
docs
marmbrus Mar 4, 2016
3b7e3a8
mima
marmbrus Mar 4, 2016
4b53adb
Merge remote-tracking branch 'apache/master' into fileDataSource
marmbrus Mar 5, 2016
bb9e092
Merge remote-tracking branch 'apache/master' into fileDataSource
marmbrus Mar 7, 2016
fd65bcb
comments
marmbrus Mar 7, 2016
3e5c7b7
Merge remote-tracking branch 'apache/master' into fileDataSource
marmbrus Mar 7, 2016
@@ -54,7 +54,8 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
override def getPartitions: Array[Partition] = {
val numParts = rdds.head.partitions.length
if (!rdds.forall(rdd => rdd.partitions.length == numParts)) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
throw new IllegalArgumentException(
s"Can't zip RDDs with unequal numbers of partitions: ${rdds.map(_.partitions.length)}")
}
Array.tabulate[Partition](numParts) { i =>
val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i)))
@@ -19,74 +19,23 @@ package org.apache.spark.ml.source.libsvm

import java.io.IOException

import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

/**
* LibSVMRelation provides the DataFrame constructed from LibSVM format data.
* @param path File path of LibSVM format
* @param numFeatures The number of features
* @param vectorType The type of vector. It can be 'sparse' or 'dense'
* @param sqlContext The Spark SQLContext
*/
private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
(@transient val sqlContext: SQLContext)
extends HadoopFsRelation with Serializable {

override def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus])
: RDD[Row] = {
val sc = sqlContext.sparkContext
val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
val sparse = vectorType == "sparse"
baseRdd.map { pt =>
val features = if (sparse) pt.features.toSparse else pt.features.toDense
Row(pt.label, features)
}
}

override def hashCode(): Int = {
Objects.hashCode(path, Double.box(numFeatures), vectorType)
}

override def equals(other: Any): Boolean = other match {
case that: LibSVMRelation =>
path == that.path &&
numFeatures == that.numFeatures &&
vectorType == that.vectorType
case _ =>
false
}

override def prepareJobForWrite(job: _root_.org.apache.hadoop.mapreduce.Job):
_root_.org.apache.spark.sql.sources.OutputWriterFactory = {
new OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new LibSVMOutputWriter(path, dataSchema, context)
}
}
}

override def paths: Array[String] = Array(path)

override def dataSchema: StructType = StructType(
StructField("label", DoubleType, nullable = false) ::
StructField("features", new VectorUDT(), nullable = false) :: Nil)
}

import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.collection.BitSet

private[libsvm] class LibSVMOutputWriter(
path: String,
@@ -124,6 +73,7 @@ private[libsvm] class LibSVMOutputWriter(
recordWriter.close(context)
}
}

/**
* `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]].
* The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and
@@ -155,7 +105,7 @@ private[libsvm] class LibSVMOutputWriter(
* @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
*/
@Since("1.6.0")
class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
class DefaultSource extends FileFormat with DataSourceRegister {

@Since("1.6.0")
override def shortName(): String = "libsvm"
@@ -167,22 +117,63 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
}
}
override def inferSchema(
sqlContext: SQLContext,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
Some(
StructType(
StructField("label", DoubleType, nullable = false) ::
StructField("features", new VectorUDT(), nullable = false) :: Nil))
}

override def createRelation(
override def prepareWrite(
sqlContext: SQLContext,
paths: Array[String],
dataSchema: Option[StructType],
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation = {
val path = if (paths.length == 1) paths(0)
else if (paths.isEmpty) throw new IOException("No input path specified for libsvm data")
else throw new IOException("Multiple input paths are not supported for libsvm data")
if (partitionColumns.isDefined && !partitionColumns.get.isEmpty) {
throw new IOException("Partition is not supported for libsvm data")
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
new OutputWriterFactory {
override def newInstance(
path: String,
bucketId: Option[Int],
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
if (bucketId.isDefined) { sys.error("LibSVM doesn't support bucketing") }
new LibSVMOutputWriter(path, dataSchema, context)
}
}
}

override def buildInternalScan(
sqlContext: SQLContext,
dataSchema: StructType,
requiredColumns: Array[String],
filters: Array[Filter],
bucketSet: Option[BitSet],
inputFiles: Array[FileStatus],
broadcastedConf: Broadcast[SerializableConfiguration],
options: Map[String, String]): RDD[InternalRow] = {
// TODO: This does not handle cases where column pruning has been performed.

verifySchema(dataSchema)
Contributor: Should we also verify the schema on write, i.e. in prepareWrite?

Contributor (author): I think that we do already, on line 69.

val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")

val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString
else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data")
else throw new IOException("Multiple input paths are not supported for libsvm data.")

val numFeatures = options.getOrElse("numFeatures", "-1").toInt
val vectorType = options.getOrElse("vectorType", "sparse")

val sc = sqlContext.sparkContext
val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
val sparse = vectorType == "sparse"
baseRdd.map { pt =>
val features = if (sparse) pt.features.toSparse else pt.features.toDense
Row(pt.label, features)
}.mapPartitions { externalRows =>
val converter = RowEncoder(dataSchema)
externalRows.map(converter.toRow)
}
dataSchema.foreach(verifySchema(_))
val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
val vectorType = parameters.getOrElse("vectorType", "sparse")
new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
}
}
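For orientation, here is a minimal usage sketch of the libsvm source described in the scaladoc above, assuming a SQLContext named sqlContext is in scope (as in the test suite below). The numFeatures and vectorType options are the ones read in buildInternalScan; the input path, feature count, and output directory are illustrative assumptions, not values from this PR.

// Usage sketch -- paths and numbers below are assumptions, not taken from this PR.
val df = sqlContext.read
  .format("libsvm")
  .option("numFeatures", "780")   // optional; the default of -1 infers the count from the data
  .option("vectorType", "dense")  // "sparse" (the default) or "dense"
  .load("data/mllib/sample_libsvm_data.txt")

// Writing back out; as in the test below, coalesce to a single partition first.
df.coalesce(1)
  .write
  .format("libsvm")
  .mode("overwrite")
  .save("/tmp/libsvm_out")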
@@ -22,7 +22,7 @@ import java.io.{File, IOException}
import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.SaveMode
@@ -88,7 +88,8 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
val df = sqlContext.read.format("libsvm").load(path)
val tempDir2 = Utils.createTempDir()
val writepath = tempDir2.toURI.toString
df.write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
// TODO: Remove requirement to coalesce by supporting mutiple reads.
df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
Contributor: I don't get this; LibSVMRelation didn't support multiple reads even before this PR, right?

Contributor: Same here. What does "multiple reads" mean here, and why is coalesce(1) required? LibSVMRelation uses textFile to read LibSVM files under the hood, so I'd assume that it can read from multiple part-files.

Contributor (author): Yeah, honestly I'm not sure where we were implicitly coalescing before, but this is required to make the test case pass. Before this PR the implementation had a restriction that throws an error if there is more than one file, and I did not try to remove that.

val df2 = sqlContext.read.format("libsvm").load(writepath)
val row1 = df2.first()
@@ -98,9 +99,8 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {

test("write libsvm data failed due to invalid schema") {
val df = sqlContext.read.format("text").load(path)
val e = intercept[IOException] {
val e = intercept[SparkException] {
df.write.format("libsvm").save(path + "_2")
}
assert(e.getMessage.contains("Illegal schema for libsvm data"))
}
}
6 changes: 5 additions & 1 deletion project/MimaExcludes.scala
@@ -60,7 +60,11 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonRDD"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.load"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.dialectClassName"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.getSQLDialect")
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.getSQLDialect"),
// SPARK-13664 Replace HadoopFsRelation with FileFormat
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.source.libsvm.LibSVMRelation"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelationProvider"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelation$FileStatusCache")
) ++ Seq(
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory")
@@ -103,7 +103,7 @@ object DataType {

/** Given the string representation of a type, return its DataType */
private def nameToType(name: String): DataType = {
val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r
Member: Why change this? scale must not be negative, because 0 <= scale <= precision.

Contributor (author): A test case is hard-coded to have a negative value here. I'm not actually sure whether that's correct, but either way I would argue that we should be more permissive about parsing (otherwise the error message is "Unknown DataType") and then throw a more sensible validation failure. (See the sketch after this hunk.)

name match {
case "decimal" => DecimalType.USER_DEFAULT
case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
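To make the "parse permissively, then validate" approach from the exchange above concrete, here is a small self-contained sketch; the object, case class, and method names are hypothetical and this is not Spark's actual implementation.

// Standalone sketch -- hypothetical names, not the Spark code path.
object DecimalNameParser {
  final case class ParsedDecimal(precision: Int, scale: Int)

  // Deliberately accept an optional minus sign on the scale at parse time,
  // so a bad value reaches validation instead of falling through as
  // "Unknown DataType".
  private val FixedDecimal = """decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)""".r

  def parse(name: String): ParsedDecimal = name match {
    case FixedDecimal(p, s) =>
      val precision = p.toInt
      val scale = s.toInt
      require(scale >= 0 && scale <= precision,
        s"Invalid decimal type '$name': expected 0 <= scale <= precision")
      ParsedDecimal(precision, scale)
    case _ =>
      throw new IllegalArgumentException(s"Unknown DataType: $name")
  }
}

// DecimalNameParser.parse("decimal(10, 2)")  returns ParsedDecimal(10, 2)
// DecimalNameParser.parse("decimal(10, -1)") fails with the validation message
// above rather than a generic parse error.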
59 changes: 28 additions & 31 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -21,18 +21,14 @@ import java.util.Properties

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.StringUtils

import org.apache.spark.{Logging, Partition}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.JSONRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.execution.datasources.json.{InferSchema, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.types.StructType

@@ -129,8 +125,6 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
val resolved = ResolvedDataSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
partitionColumns = Array.empty[String],
bucketSpec = None,
provider = source,
options = extraOptions.toMap)
DataFrame(sqlContext, LogicalRelation(resolved.relation))
@@ -154,7 +148,17 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*/
@scala.annotation.varargs
def load(paths: String*): DataFrame = {
option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load()
if (paths.isEmpty) {
sqlContext.emptyDataFrame
} else {
sqlContext.baseRelationToDataFrame(
ResolvedDataSource.apply(
sqlContext,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
provider = source,
options = extraOptions.toMap).relation)
}
}

/**
@@ -334,14 +338,20 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* @since 1.4.0
*/
def json(jsonRDD: RDD[String]): DataFrame = {
sqlContext.baseRelationToDataFrame(
new JSONRelation(
Some(jsonRDD),
maybeDataSchema = userSpecifiedSchema,
maybePartitionSpec = None,
userDefinedPartitionColumns = None,
parameters = extraOptions.toMap)(sqlContext)
)
val parsedOptions: JSONOptions = new JSONOptions(extraOptions.toMap)
val schema = userSpecifiedSchema.getOrElse {
InferSchema.infer(jsonRDD, sqlContext.conf.columnNameOfCorruptRecord, parsedOptions)
}

new DataFrame(
sqlContext,
LogicalRDD(
schema.toAttributes,
JacksonParser.parse(
jsonRDD,
schema,
sqlContext.conf.columnNameOfCorruptRecord,
parsedOptions))(sqlContext))
}

/**
Expand All @@ -363,20 +373,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*/
@scala.annotation.varargs
def parquet(paths: String*): DataFrame = {
if (paths.isEmpty) {
sqlContext.emptyDataFrame
} else {
val globbedPaths = paths.flatMap { path =>
val hdfsPath = new Path(path)
val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
SparkHadoopUtil.get.globPathIfNecessary(qualified)
}.toArray

sqlContext.baseRelationToDataFrame(
new ParquetRelation(
globbedPaths.map(_.toString), userSpecifiedSchema, None, extraOptions.toMap)(sqlContext))
}
format("parquet").load(paths: _*)
}

/**
@@ -366,13 +366,6 @@ final class DataFrameWriter private[sql](df: DataFrame) {
case (true, SaveMode.ErrorIfExists) =>
throw new AnalysisException(s"Table $tableIdent already exists.")

case (true, SaveMode.Append) =>
// If it is Append, we just ask insertInto to handle it. We will not use insertInto
// to handle saveAsTable with Overwrite because saveAsTable can change the schema of
// the table. But, insertInto with Overwrite requires the schema of data be the same
// the schema of the table.
insertInto(tableIdent)

case _ =>
Contributor: Why remove this?

Contributor (author): I consolidated the code paths. Any write to a HadoopFsRelation now goes through InsertIntoHadoopFsRelation.

val cmd =
CreateTableUsingAsSelect(
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetSource}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation}
@@ -226,16 +226,17 @@ private[sql] object PhysicalRDD {
rdd: RDD[InternalRow],
relation: BaseRelation,
metadata: Map[String, String] = Map.empty): PhysicalRDD = {
val outputUnsafeRows = if (relation.isInstanceOf[ParquetRelation]) {
// The vectorized parquet reader does not produce unsafe rows.
!SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
} else {
// All HadoopFsRelations output UnsafeRows
relation.isInstanceOf[HadoopFsRelation]

val outputUnsafeRows = relation match {
case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
!SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
case _: HadoopFsRelation => true
case _ => false
Contributor: Just to confirm: after merging this and all the other planned HadoopFsRelation-related updates, we will still have built-in data sources returning UnsafeRows and external data source packages returning Rows, right?

Contributor (author): Yes. There is no interface to do that today, since the only version returns an RDD[InternalRow], but SPARK-13682 tracks fixing that. When we do, we should avoid hacks that rely on erasure in the bytecode. Instead, have an internal function that returns InternalRow by converting the result of calling the public version, which returns Row. Internal implementations like Parquet can circumvent this extra step and just override the conversion function. (A sketch of this conversion pattern follows the diff below.)

}

val bucketSpec = relation match {
case r: HadoopFsRelation => r.getBucketSpec
// TODO: this should be closer to bucket planning.
case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled() => r.bucketSpec
case _ => None
}

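Following up on the review exchange above about keeping a public Row-based scan while using InternalRow internally (SPARK-13682), here is a hedged sketch of that conversion pattern. The trait and method names are hypothetical and not part of this PR; only the RowEncoder-based conversion mirrors what the libsvm buildInternalScan in this diff already does.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StructType

// Hypothetical interface, sketching the idea from the discussion above.
trait ExternalScanSupport {
  // Public, stable surface: external sources produce ordinary Rows.
  def buildScan(dataSchema: StructType, files: Seq[String]): RDD[Row]

  // Internal entry point: convert to InternalRow by default. Built-in formats
  // such as Parquet could override this and skip the conversion entirely.
  def buildInternalScan(dataSchema: StructType, files: Seq[String]): RDD[InternalRow] =
    buildScan(dataSchema, files).mapPartitions { rows =>
      val converter = RowEncoder(dataSchema)  // same conversion used in the libsvm diff above
      rows.map(converter.toRow)
    }
}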