@@ -26,7 +26,8 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.{InferSchema, JacksonParser, JSONOptions}
import org.apache.spark.sql.types.StructType
@@ -408,6 +409,39 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
@scala.annotation.varargs
def csv(paths: String*): DataFrame = format("csv").load(paths : _*)

/**
* Loads a `Dataset[String]` storing CSV rows (one row per record) and
* returns the result as a [[DataFrame]].
*
* Unless the schema is specified using the [[schema]] function, this function goes through the
* input once to determine the input schema.
*
* @param csvDS input Dataset with one CSV row per record
* @since 2.1.0
*/
def csv(csvDS: Dataset[String]): DataFrame = {
val csvRDD = csvDS.rdd
val parsedOptions: CSVOptions = new CSVOptions(extraOptions.toMap)
val header = CSVRelation.getHeader(csvRDD, parsedOptions)
val parsedRDD = CSVRelation.tokenRdd(parsedOptions, header, csvRDD)
val schema = userSpecifiedSchema.getOrElse {
CSVInferSchema.infer(parsedRDD, header, parsedOptions)
}

val parser = CSVRelation.csvParser(schema, schema.fields.map(_.name), parsedOptions)
var numMalformedRecords = 0
val rows = parsedRDD.flatMap { recordTokens =>
val row = parser(recordTokens, numMalformedRecords)
if (row.isEmpty) {
numMalformedRecords += 1
}
row
}
Dataset.ofRows(
sparkSession,
LogicalRDD(schema.toAttributes, rows)(sparkSession))
}

/**
* Loads a Parquet file, returning the result as a [[DataFrame]]. See the documentation
* on the other overloaded `parquet()` method for more details.
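For reference, the new `csv(csvDS: Dataset[String])` overload above could be exercised roughly as follows. This is only a sketch: it assumes an active `SparkSession` named `spark` (as in spark-shell), and the data and column names are made up.

```scala
import org.apache.spark.sql.types._
import spark.implicits._

// One CSV record per Dataset element (illustrative data).
val csvLines = Seq("alice,29", "bob,31").toDS()

// With an explicit schema, no extra pass over the data is needed.
val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType)))
val withSchema = spark.read.schema(schema).csv(csvLines)

// Without one, the data is scanned once to determine the schema
// (all-string columns unless inferSchema is enabled).
val inferred = spark.read.option("inferSchema", "true").csv(csvLines)

withSchema.show()
inferred.printSchema()
```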
@@ -54,37 +54,18 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {

// TODO: Move filtering.
val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString)
val rdd = baseRdd(sparkSession, csvOptions, paths)
val firstLine = findFirstLine(csvOptions, rdd)
val firstRow = new CsvReader(csvOptions).parseLine(firstLine)

val header = if (csvOptions.headerFlag) {
firstRow.zipWithIndex.map { case (value, index) =>
if (value == null || value.isEmpty || value == csvOptions.nullValue) s"_c$index" else value
}
} else {
firstRow.zipWithIndex.map { case (value, index) => s"_c$index" }
}

val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths)
val schema = if (csvOptions.inferSchemaFlag) {
CSVInferSchema.infer(parsedRdd, header, csvOptions)
} else {
// By default fields are assumed to be StringType
val schemaFields = header.map { fieldName =>
StructField(fieldName.toString, StringType, nullable = true)
}
StructType(schemaFields)
}
Some(schema)
val rdd = CSVRelation.baseRdd(sparkSession, csvOptions, paths)
val header = CSVRelation.getHeader(rdd, csvOptions)
val parsedRdd = CSVRelation.tokenRdd(csvOptions, header, rdd)
Some(CSVInferSchema.infer(parsedRdd, header, csvOptions))
}

override def prepareWrite(
sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
verifySchema(dataSchema)
CSVRelation.verifySchema(dataSchema)
val conf = job.getConfiguration
val csvOptions = new CSVOptions(options)
csvOptions.compressionCodec.foreach { codec =>
@@ -136,68 +117,4 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
}
}

private def baseRdd(
sparkSession: SparkSession,
options: CSVOptions,
inputPaths: Seq[String]): RDD[String] = {
readText(sparkSession, options, inputPaths.mkString(","))
}

private def tokenRdd(
sparkSession: SparkSession,
options: CSVOptions,
header: Array[String],
inputPaths: Seq[String]): RDD[Array[String]] = {
val rdd = baseRdd(sparkSession, options, inputPaths)
// Make sure firstLine is materialized before sending to executors
val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null
CSVRelation.univocityTokenizer(rdd, firstLine, options)
}

/**
* Returns the first line of the first non-empty file in path
*/
private def findFirstLine(options: CSVOptions, rdd: RDD[String]): String = {
if (options.isCommentSet) {
val comment = options.comment.toString
rdd.filter { line =>
line.trim.nonEmpty && !line.startsWith(comment)
}.first()
} else {
rdd.filter { line =>
line.trim.nonEmpty
}.first()
}
}

private def readText(
sparkSession: SparkSession,
options: CSVOptions,
location: String): RDD[String] = {
if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
sparkSession.sparkContext.textFile(location)
} else {
val charset = options.charset
sparkSession.sparkContext
.hadoopFile[LongWritable, Text, TextInputFormat](location)
.mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
}
}

private def verifySchema(schema: StructType): Unit = {
def verifyType(dataType: DataType): Unit = dataType match {
case ByteType | ShortType | IntegerType | LongType | FloatType |
DoubleType | BooleanType | _: DecimalType | TimestampType |
DateType | StringType =>

case udt: UserDefinedType[_] => verifyType(udt.sqlType)

case _ =>
throw new UnsupportedOperationException(
s"CSV data source does not support ${dataType.simpleString} data type.")
}

schema.foreach(field => verifyType(field.dataType))
}
}
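As a side note, `verifySchema` now lives on the `CSVRelation` object (see the CSVRelation.scala hunk further down), so its behavior can be sketched directly. The schemas below are invented for illustration:

```scala
import org.apache.spark.sql.execution.datasources.csv.CSVRelation
import org.apache.spark.sql.types._

// Atomic types supported by the CSV source pass the check silently.
CSVRelation.verifySchema(StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType))))

// Nested types are rejected; this call is expected to throw
// UnsupportedOperationException: "CSV data source does not support array<int> data type."
CSVRelation.verifySchema(StructType(Seq(
  StructField("scores", ArrayType(IntegerType)))))
```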
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

private[csv] object CSVInferSchema {
object CSVInferSchema {

/**
* Similar to the JSON schema inference
@@ -42,16 +42,23 @@ private[csv] object CSVInferSchema {
tokenRdd: RDD[Array[String]],
header: Array[String],
options: CSVOptions): StructType = {
    val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
    val rootTypes: Array[DataType] =
      tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)

    val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
      val dType = rootType match {
        case _: NullType => StringType
        case other => other
      }
      StructField(thisHeader, dType, nullable = true)
    }

    val structFields = if (options.inferSchemaFlag) {
      val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
      val rootTypes: Array[DataType] =
        tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)

      header.zip(rootTypes).map { case (thisHeader, rootType) =>
        val dType = rootType match {
          case _: NullType => StringType
          case other => other
        }
        StructField(thisHeader, dType, nullable = true)
      }
    } else {
      // By default fields are assumed to be StringType
      header.map { fieldName =>
        StructField(fieldName.toString, StringType, nullable = true)
      }
    }

    StructType(structFields)

PR author's comment on this hunk: This method is used in both csv.DefaultSource and DataFrameReader.csv(ds: Dataset[String]), so I refactored it here to handle both the default StringType schema and the inferSchemaFlag=true case.
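To make the two branches concrete: through the public API the flag corresponds to the `inferSchema` CSV option. A rough illustration with invented data, again assuming a `SparkSession` named `spark`:

```scala
import spark.implicits._

val lines = Seq("id,price", "1,9.99", "2,12.50").toDS()

// inferSchemaFlag = true: per-column types are aggregated from the parsed tokens.
spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv(lines)
  .printSchema()      // expected: id int, price double

// inferSchemaFlag = false (the default): every field is left as StringType.
spark.read
  .option("header", "true")
  .csv(lines)
  .printSchema()      // expected: id string, price string
```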
@@ -24,7 +24,7 @@ import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes}

private[csv] class CSVOptions(@transient private val parameters: Map[String, String])
class CSVOptions(@transient private val parameters: Map[String, String])
extends Logging with Serializable {

private def getChar(paramName: String, default: Char): Char = {
@@ -17,10 +17,13 @@

package org.apache.spark.sql.execution.datasources.csv

import java.nio.charset.{Charset, StandardCharsets}

import scala.util.control.NonFatal

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.RecordWriter
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
@@ -166,6 +169,80 @@ object CSVRelation extends Logging {
if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
}
}

def baseRdd(
sparkSession: SparkSession,
options: CSVOptions,
inputPaths: Seq[String]): RDD[String] = {
readText(sparkSession, options, inputPaths.mkString(","))
}

def tokenRdd(
options: CSVOptions,
header: Array[String],
rdd: RDD[String]): RDD[Array[String]] = {
val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null
univocityTokenizer(rdd, firstLine, options)
}

/**
* Returns the first line of the first non-empty file in path
*/
def findFirstLine(options: CSVOptions, rdd: RDD[String]): String = {
if (options.isCommentSet) {
val comment = options.comment.toString
rdd.filter { line =>
line.trim.nonEmpty && !line.startsWith(comment)
}.first()
} else {
rdd.filter { line =>
line.trim.nonEmpty
}.first()
}
}

def readText(
sparkSession: SparkSession,
options: CSVOptions,
location: String): RDD[String] = {
if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
sparkSession.sparkContext.textFile(location)
} else {
val charset = options.charset
sparkSession.sparkContext
.hadoopFile[LongWritable, Text, TextInputFormat](location)
.mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
}
}

def verifySchema(schema: StructType): Unit = {
def verifyType(dataType: DataType): Unit = dataType match {
case ByteType | ShortType | IntegerType | LongType | FloatType |
DoubleType | BooleanType | _: DecimalType | TimestampType |
DateType | StringType =>

case udt: UserDefinedType[_] => verifyType(udt.sqlType)

case _ =>
throw new UnsupportedOperationException(
s"CSV data source does not support ${dataType.simpleString} data type.")
}

schema.foreach(field => verifyType(field.dataType))
}

def getHeader(rdd: RDD[String], csvOptions: CSVOptions): Array[String] = {
val firstLine = findFirstLine(csvOptions, rdd)
val firstRow = new CsvReader(csvOptions).parseLine(firstLine)

if (csvOptions.headerFlag) {
firstRow.zipWithIndex.map { case (value, index) =>
if (value == null || value.isEmpty || value == csvOptions.nullValue) s"_c$index" else value
}
} else {
firstRow.zipWithIndex.map { case (value, index) => s"_c$index" }
}
}
}

private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
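Finally, a small sketch of what the new `CSVRelation.getHeader` returns under the two header settings. The sample lines and option maps are illustrative, and a `SparkSession` named `spark` is assumed:

```scala
import org.apache.spark.sql.execution.datasources.csv.{CSVOptions, CSVRelation}

val lines = spark.sparkContext.parallelize(Seq("name,age", "alice,29", "bob,31"))

// header=true: the first non-empty line supplies the column names.
CSVRelation.getHeader(lines, new CSVOptions(Map("header" -> "true")))
// expected: Array("name", "age")

// header=false: positional names _c0, _c1, ... are generated instead.
CSVRelation.getHeader(lines, new CSVOptions(Map("header" -> "false")))
// expected: Array("_c0", "_c1")
```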