@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.InferSchema
import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
import org.apache.spark.sql.types.StructType

/**
@@ -334,7 +334,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
parsedOptions.columnNameOfCorruptRecord
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val schema = userSpecifiedSchema.getOrElse {
InferSchema.infer(
JsonInferSchema.infer(
jsonRDD,
columnNameOfCorruptRecord,
parsedOptions)
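
The inference above runs only when the caller did not supply a schema. An illustrative sketch, not part of this patch, assuming a running SparkSession named spark and a hypothetical data.json:

// Schema inferred from the data (the JsonInferSchema.infer path above).
val inferred = spark.read.json("data.json")

// A user-specified schema bypasses inference entirely.
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
val explicit = spark.read.schema(schema).json("data.json")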
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv

import java.nio.charset.{Charset, StandardCharsets}

import com.univocity.parsers.csv.CsvParser
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, Text}
@@ -28,13 +27,11 @@ import org.apache.hadoop.mapreduce._

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.functions.{length, trim}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
@@ -60,72 +57,17 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {

val csvOptions = new CSVOptions(options)
val paths = files.map(_.getPath.toString)
val lines: Dataset[String] = readText(sparkSession, csvOptions, paths)
val firstLine: String = findFirstLine(csvOptions, lines)
val firstRow = new CsvParser(csvOptions.asParserSettings).parseLine(firstLine)
val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, paths)
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)

val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer(
lines,
firstLine = if (csvOptions.headerFlag) firstLine else null,
params = csvOptions)
val schema = if (csvOptions.inferSchemaFlag) {
CSVInferSchema.infer(parsedRdd, header, csvOptions)
} else {
// By default fields are assumed to be StringType
val schemaFields = header.map { fieldName =>
StructField(fieldName, StringType, nullable = true)
}
StructType(schemaFields)
}
Some(schema)
}

/**
* Generates a header from the given row which is null-safe and duplicate-safe.
*/
private def makeSafeHeader(
row: Array[String],
options: CSVOptions,
caseSensitive: Boolean): Array[String] = {
if (options.headerFlag) {
val duplicates = {
val headerNames = row.filter(_ != null)
.map(name => if (caseSensitive) name else name.toLowerCase)
headerNames.diff(headerNames.distinct).distinct
}

row.zipWithIndex.map { case (value, index) =>
if (value == null || value.isEmpty || value == options.nullValue) {
// When there are empty strings or the values set in `nullValue`, put the
// index as the suffix.
s"_c$index"
} else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
// When there are case-insensitive duplicates, put the index as the suffix.
s"$value$index"
} else if (duplicates.contains(value)) {
// When there are duplicates, put the index as the suffix.
s"$value$index"
} else {
value
}
}
} else {
row.zipWithIndex.map { case (_, index) =>
// Uses default column names, "_c#" where # is its position of fields
// when header option is disabled.
s"_c$index"
}
}
Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions))
}
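
An illustrative, user-facing sketch of what this method now drives (not part of the patch; assumes a running SparkSession named spark and a hypothetical people.csv):

// Column names come from the header row; types are inferred by CSVInferSchema.infer.
val typed = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("people.csv")

// With neither option set, columns are named _c0, _c1, ... and typed as StringType.
val untyped = spark.read.csv("people.csv")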

override def prepareWrite(
sparkSession: SparkSession,
job: Job,
Member Author comment: This removed block was all moved into CSVInferSchema.infer(...).

options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
verifySchema(dataSchema)
CSVUtils.verifySchema(dataSchema)
val conf = job.getConfiguration
val csvOptions = new CSVOptions(options)
csvOptions.compressionCodec.foreach { codec =>
@@ -155,13 +97,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
val csvOptions = new CSVOptions(options)
val commentPrefix = csvOptions.comment.toString

val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

(file: PartitionedFile) => {
val lineIterator = {
val lines = {
val conf = broadcastedHadoopConf.value.value
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
@@ -170,32 +111,21 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
}

// Consumes the header in the iterator.
CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)

val filteredIter = lineIterator.filter { line =>
line.trim.nonEmpty && !line.startsWith(commentPrefix)
val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) {
// Note that if there are only comments in the first block, the header would
// probably not be dropped.
CSVUtils.dropHeaderLine(lines, csvOptions)
} else {
lines
}

val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, csvOptions)
val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions)
filteredIter.flatMap(parser.parse)
}
}

/**
* Returns the first line of the first non-empty file in path
*/
private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = {
import lines.sqlContext.implicits._
val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
if (options.isCommentSet) {
nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first()
} else {
nonEmptyLines.first()
filteredLines.flatMap(parser.parse)
}
}

private def readText(
private def createBaseDataset(
sparkSession: SparkSession,
options: CSVOptions,
inputPaths: Seq[String]): Dataset[String] = {
@@ -215,22 +145,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
sparkSession.createDataset(rdd)(Encoders.STRING)
}
}

private def verifySchema(schema: StructType): Unit = {
def verifyType(dataType: DataType): Unit = dataType match {
case ByteType | ShortType | IntegerType | LongType | FloatType |
DoubleType | BooleanType | _: DecimalType | TimestampType |
DateType | StringType =>

case udt: UserDefinedType[_] => verifyType(udt.sqlType)

case _ =>
throw new UnsupportedOperationException(
s"CSV data source does not support ${dataType.simpleString} data type.")
}

schema.foreach(field => verifyType(field.dataType))
}
}
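
For reference, the schema check that used to live here is now CSVUtils.verifySchema (called from prepareWrite above); it rejects column types CSV cannot represent. An illustrative sketch, not part of the patch, assuming a running SparkSession named spark and a hypothetical output path:

import org.apache.spark.sql.functions.{array, col}

// A DataFrame with an array column, which CSV cannot represent.
val df = spark.range(3).withColumn("arr", array(col("id")))

// prepareWrite calls verifySchema, so this fails before any data is written with
// UnsupportedOperationException: CSV data source does not support array<bigint> data type.
df.write.csv("/tmp/unsupported-out")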

private[csv] class CsvOutputWriter(
@@ -18,17 +18,15 @@
package org.apache.spark.sql.execution.datasources.csv

import java.math.BigDecimal
import java.text.NumberFormat
import java.util.Locale

import scala.util.control.Exception._
import scala.util.Try

import org.apache.spark.rdd.RDD
import com.univocity.parsers.csv.CsvParser

import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

private[csv] object CSVInferSchema {

@@ -39,22 +37,76 @@ private[csv] object CSVInferSchema {
* 3. Replace any null types with string type
*/
def infer(
tokenRdd: RDD[Array[String]],
header: Array[String],
csv: Dataset[String],
Contributor comment: nit: csvLines

Member Author comment: This one too; I just mirrored json.InferSchema.infer:

  def infer(
      json: RDD[String],

caseSensitive: Boolean,
options: CSVOptions): StructType = {
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val rootTypes: Array[DataType] =
tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)

val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
val dType = rootType match {
case _: NullType => StringType
case other => other
val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first()
Member Author comment (HyukjinKwon, Jan 25, 2017): Both uses of CSVUtils.filterCommentAndEmpty here and below should behave exactly the same, to my knowledge, but I have left them as they are simply to keep the existing behaviour for now.

Contributor comment: Please send a follow-up PR for this.
val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine)
val header = makeSafeHeader(firstRow, caseSensitive, options)

val fields = if (options.inferSchemaFlag) {
val tokenRdd = csv.rdd.mapPartitions { iter =>
val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options)
val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options)
val parser = new CsvParser(options.asParserSettings)
linesWithoutHeader.map(parser.parseLine)
}

val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val rootTypes: Array[DataType] =
tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)

header.zip(rootTypes).map { case (thisHeader, rootType) =>
val dType = rootType match {
case _: NullType => StringType
case other => other
}
StructField(thisHeader, dType, nullable = true)
}
StructField(thisHeader, dType, nullable = true)
} else {
// By default fields are assumed to be StringType
header.map(fieldName => StructField(fieldName, StringType, nullable = true))
}

StructType(structFields)
StructType(fields)
}

/**
* Generates a header from the given row which is null-safe and duplicate-safe.
*/
private def makeSafeHeader(
row: Array[String],
caseSensitive: Boolean,
options: CSVOptions): Array[String] = {
if (options.headerFlag) {
val duplicates = {
val headerNames = row.filter(_ != null)
.map(name => if (caseSensitive) name else name.toLowerCase)
headerNames.diff(headerNames.distinct).distinct
}

row.zipWithIndex.map { case (value, index) =>
if (value == null || value.isEmpty || value == options.nullValue) {
// When there are empty strings or the values set in `nullValue`, put the
// index as the suffix.
s"_c$index"
} else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
// When there are case-insensitive duplicates, put the index as the suffix.
s"$value$index"
} else if (duplicates.contains(value)) {
// When there are duplicates, put the index as the suffix.
s"$value$index"
} else {
value
}
}
} else {
row.zipWithIndex.map { case (_, index) =>
// Uses default column names, "_c#" where # is its position of fields
// when header option is disabled.
s"_c$index"
}
}
}
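
To make the renaming rules above concrete, here is a minimal self-contained sketch; safeHeaderSketch is a hypothetical stand-in for the private makeSafeHeader, assuming case-insensitive analysis and an empty-string nullValue:

def safeHeaderSketch(row: Array[String]): Array[String] = {
  // Case-insensitive duplicate detection, ignoring nulls.
  val names = row.filter(_ != null).map(_.toLowerCase)
  val duplicates = names.diff(names.distinct).distinct

  row.zipWithIndex.map {
    case (v, i) if v == null || v.isEmpty             => s"_c$i"  // null or empty: positional name
    case (v, i) if duplicates.contains(v.toLowerCase) => s"$v$i"  // duplicate: index suffix
    case (v, _)                                       => v
  }
}

// safeHeaderSketch(Array("a", null, "A", "")) returns Array("a0", "_c1", "A2", "_c3")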

private def inferRowType(options: CSVOptions)
@@ -215,32 +267,3 @@ private[csv] object CSVInferSchema {
case _ => None
}
}

private[csv] object CSVTypeCast {
/**
* Helper method that converts string representation of a character to actual character.
* It handles some Java escaped strings and throws exception if given string is longer than one
* character.
*/
@throws[IllegalArgumentException]
def toChar(str: String): Char = {
if (str.charAt(0) == '\\') {
str.charAt(1)
match {
case 't' => '\t'
case 'r' => '\r'
case 'b' => '\b'
case 'f' => '\f'
case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options
case '\'' => '\''
case 'u' if str == """\u0000""" => '\u0000'
case _ =>
throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
}
} else if (str.length == 1) {
str.charAt(0)
} else {
throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
}
}
}
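
The toChar helper removed here lives on as CSVUtils.toChar (see the CSVOptions hunk below); it is what lets an escaped delimiter string from the options map resolve to a single character. An illustrative sketch, not part of the patch, assuming a running SparkSession named spark and a hypothetical data.tsv:

// Either form works: a literal tab character...
val a = spark.read.option("sep", "\t").csv("data.tsv")
// ...or the two-character escape "\\t", which toChar maps to '\t'.
val b = spark.read.option("sep", "\\t").csv("data.tsv")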
@@ -69,7 +69,7 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive
}
}

val delimiter = CSVTypeCast.toChar(
val delimiter = CSVUtils.toChar(
parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
val charset = parameters.getOrElse("encoding",