Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources.csv

import java.math.BigDecimal
import java.text.NumberFormat
import java.text.{NumberFormat, SimpleDateFormat}
import java.util.Locale

import scala.util.control.Exception._
Expand All @@ -41,11 +41,10 @@ private[csv] object CSVInferSchema {
def infer(
tokenRdd: RDD[Array[String]],
header: Array[String],
nullValue: String = ""): StructType = {

options: CSVOptions): StructType = {
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val rootTypes: Array[DataType] =
tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)

val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
val dType = rootType match {
Expand All @@ -58,11 +57,11 @@ private[csv] object CSVInferSchema {
StructType(structFields)
}

private def inferRowType(nullValue: String)
private def inferRowType(options: CSVOptions)
(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
var i = 0
while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
rowSoFar(i) = inferField(rowSoFar(i), next(i), options)
i+=1
}
rowSoFar
Expand All @@ -78,53 +77,67 @@ private[csv] object CSVInferSchema {
* Infer type of string field. Given known type Double, and a string "1", there is no
* point checking if it is an Int, as the final type must be Double or higher.
*/
def inferField(typeSoFar: DataType, field: String, nullValue: String = ""): DataType = {
if (field == null || field.isEmpty || field == nullValue) {
def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = {
if (field == null || field.isEmpty || field == options.nullValue) {
typeSoFar
} else {
typeSoFar match {
case NullType => tryParseInteger(field)
case IntegerType => tryParseInteger(field)
case LongType => tryParseLong(field)
case DoubleType => tryParseDouble(field)
case TimestampType => tryParseTimestamp(field)
case BooleanType => tryParseBoolean(field)
case NullType => tryParseInteger(field, options)
case IntegerType => tryParseInteger(field, options)
case LongType => tryParseLong(field, options)
case DoubleType => tryParseDouble(field, options)
case TimestampType => tryParseTimestamp(field, options)
case BooleanType => tryParseBoolean(field, options)
case StringType => StringType
case other: DataType =>
throw new UnsupportedOperationException(s"Unexpected data type $other")
}
}
}

private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
IntegerType
} else {
tryParseLong(field)
private def tryParseInteger(field: String, options: CSVOptions): DataType = {
if ((allCatch opt field.toInt).isDefined) {
IntegerType
} else {
tryParseLong(field, options)
}
}

private def tryParseLong(field: String): DataType = if ((allCatch opt field.toLong).isDefined) {
LongType
} else {
tryParseDouble(field)
private def tryParseLong(field: String, options: CSVOptions): DataType = {
if ((allCatch opt field.toLong).isDefined) {
LongType
} else {
tryParseDouble(field, options)
}
}

private def tryParseDouble(field: String): DataType = {
private def tryParseDouble(field: String, options: CSVOptions): DataType = {
if ((allCatch opt field.toDouble).isDefined) {
DoubleType
} else {
tryParseTimestamp(field)
tryParseTimestamp(field, options)
}
}

def tryParseTimestamp(field: String): DataType = {
if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
TimestampType
private def tryParseTimestamp(field: String, options: CSVOptions): DataType = {
if (options.dateFormat != null) {
// This case infers a custom `dataFormat` is set.
if ((allCatch opt options.dateFormat.parse(field)).isDefined) {
TimestampType
} else {
tryParseBoolean(field, options)
}
} else {
tryParseBoolean(field)
// We keep this for backwords competibility.
if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
TimestampType
} else {
tryParseBoolean(field, options)
}
}
}

def tryParseBoolean(field: String): DataType = {
private def tryParseBoolean(field: String, options: CSVOptions): DataType = {
if ((allCatch opt field.toBoolean).isDefined) {
BooleanType
} else {
Expand Down Expand Up @@ -177,7 +190,8 @@ private[csv] object CSVTypeCast {
datum: String,
castType: DataType,
nullable: Boolean = true,
nullValue: String = ""): Any = {
nullValue: String = "",
dateFormat: SimpleDateFormat = null): Any = {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will conflict with #11947 (this is making the last argument as a single option not for the multiple individual parameters like I did for infer() above).


if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) {
null
Expand All @@ -195,12 +209,16 @@ private[csv] object CSVTypeCast {
case dt: DecimalType =>
val value = new BigDecimal(datum.replaceAll(",", ""))
Decimal(value, dt.precision, dt.scale)
// TODO(hossein): would be good to support other common timestamp formats
case _: TimestampType if dateFormat != null =>
// This one will lose microseconds parts.
// See https://issues.apache.org/jira/browse/SPARK-10681.
dateFormat.parse(datum).getTime * 1000L
case _: TimestampType =>
// This one will lose microseconds parts.
// See https://issues.apache.org/jira/browse/SPARK-10681.
DateTimeUtils.stringToTime(datum).getTime * 1000L
// TODO(hossein): would be good to support other common date formats
case _: DateType if dateFormat != null =>
DateTimeUtils.millisToDays(dateFormat.parse(datum).getTime)
case _: DateType =>
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
case _: StringType => UTF8String.fromString(datum)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources.csv

import java.nio.charset.StandardCharsets
import java.text.SimpleDateFormat

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes}
Expand Down Expand Up @@ -90,6 +91,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
name.map(CompressionCodecs.getCodecClassName)
}

// Share date format object as it is expensive to parse date pattern.
val dateFormat: SimpleDateFormat = {
val dateFormat = parameters.get("dateFormat")
dateFormat.map(new SimpleDateFormat(_)).orNull
}

val maxColumns = getInt("maxColumns", 20480)

val maxCharsPerColumn = getInt("maxCharsPerColumn", 1000000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ object CSVRelation extends Logging {
indexSafeTokens(index),
field.dataType,
field.nullable,
params.nullValue)
params.nullValue,
params.dateFormat)
if (subIndex < requiredSize) {
row(subIndex) = value
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {

val parsedRdd = tokenRdd(sqlContext, csvOptions, header, paths)
val schema = if (csvOptions.inferSchemaFlag) {
CSVInferSchema.infer(parsedRdd, header, csvOptions.nullValue)
CSVInferSchema.infer(parsedRdd, header, csvOptions)
} else {
// By default fields are assumed to be StringType
val schemaFields = header.map { fieldName =>
Expand Down
4 changes: 4 additions & 0 deletions sql/core/src/test/resources/dates.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
date
26/08/2015 18:00
27/10/2014 18:30
28/01/2016 20:00
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,58 @@

package org.apache.spark.sql.execution.datasources.csv

import java.text.SimpleDateFormat

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._

class CSVInferSchemaSuite extends SparkFunSuite {

test("String fields types are inferred correctly from null types") {
assert(CSVInferSchema.inferField(NullType, "") == NullType)
assert(CSVInferSchema.inferField(NullType, null) == NullType)
assert(CSVInferSchema.inferField(NullType, "100000000000") == LongType)
assert(CSVInferSchema.inferField(NullType, "60") == IntegerType)
assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType)
assert(CSVInferSchema.inferField(NullType, "test") == StringType)
assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
assert(CSVInferSchema.inferField(NullType, "True") == BooleanType)
assert(CSVInferSchema.inferField(NullType, "FAlSE") == BooleanType)
val options = new CSVOptions(Map.empty[String, String])
assert(CSVInferSchema.inferField(NullType, "", options) == NullType)
assert(CSVInferSchema.inferField(NullType, null, options) == NullType)
assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType)
assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType)
assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType)
assert(CSVInferSchema.inferField(NullType, "test", options) == StringType)
assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType)
assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType)
assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType)
}

test("String fields types are inferred correctly from other types") {
assert(CSVInferSchema.inferField(LongType, "1.0") == DoubleType)
assert(CSVInferSchema.inferField(LongType, "test") == StringType)
assert(CSVInferSchema.inferField(IntegerType, "1.0") == DoubleType)
assert(CSVInferSchema.inferField(DoubleType, null) == DoubleType)
assert(CSVInferSchema.inferField(DoubleType, "test") == StringType)
assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType)
assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType)
assert(CSVInferSchema.inferField(LongType, "True") == BooleanType)
assert(CSVInferSchema.inferField(IntegerType, "FALSE") == BooleanType)
assert(CSVInferSchema.inferField(TimestampType, "FALSE") == BooleanType)
val options = new CSVOptions(Map.empty[String, String])
assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType)
assert(CSVInferSchema.inferField(LongType, "test", options) == StringType)
assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType)
assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType)
assert(CSVInferSchema.inferField(DoubleType, "test", options) == StringType)
assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) == TimestampType)
assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", options) == TimestampType)
assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType)
assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType)
assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType)
}

test("Timestamp field types are inferred correctly via custom data format") {
var options = new CSVOptions(Map("dateFormat" -> "yyyy-mm"))
assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType)
options = new CSVOptions(Map("dateFormat" -> "yyyy"))
assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType)
}

test("Timestamp field types are inferred correctly from other types") {
assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14") == StringType)
assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10") == StringType)
assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
val options = new CSVOptions(Map.empty[String, String])
assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType)
assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType)
assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType)
}

test("Boolean fields types are inferred correctly from other types") {
assert(CSVInferSchema.inferField(LongType, "Fale") == StringType)
assert(CSVInferSchema.inferField(DoubleType, "TRUEe") == StringType)
val options = new CSVOptions(Map.empty[String, String])
assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType)
assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType)
}

test("Type arrays are merged to highest common type") {
Expand All @@ -71,13 +84,16 @@ class CSVInferSchemaSuite extends SparkFunSuite {
}

test("Null fields are handled properly when a nullValue is specified") {
assert(CSVInferSchema.inferField(NullType, "null", "null") == NullType)
assert(CSVInferSchema.inferField(StringType, "null", "null") == StringType)
assert(CSVInferSchema.inferField(LongType, "null", "null") == LongType)
assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType)
var options = new CSVOptions(Map("nullValue" -> "null"))
assert(CSVInferSchema.inferField(NullType, "null", options) == NullType)
assert(CSVInferSchema.inferField(StringType, "null", options) == StringType)
assert(CSVInferSchema.inferField(LongType, "null", options) == LongType)

options = new CSVOptions(Map("nullValue" -> "\\N"))
assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType)
assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType)
assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType)
}

test("Merging Nulltypes should yield Nulltype.") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.datasources.csv

import java.io.File
import java.nio.charset.UnsupportedCharsetException
import java.sql.Timestamp
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat

import scala.collection.JavaConverters._

Expand All @@ -45,6 +46,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val disableCommentsFile = "disable_comments.csv"
private val boolFile = "bool.csv"
private val simpleSparseFile = "simple_sparse.csv"
private val datesFile = "dates.csv"
private val unescapedQuotesFile = "unescaped-quotes.csv"

private def testFile(fileName: String): String = {
Expand Down Expand Up @@ -367,6 +369,54 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(results.toSeq.map(_.toSeq) === expected)
}

test("inferring timestamp types via custom date format") {
val options = Map(
"header" -> "true",
"inferSchema" -> "true",
"dateFormat" -> "dd/MM/yyyy hh:mm")
val results = sqlContext.read
.format("csv")
.options(options)
.load(testFile(datesFile))
.select("date")
.collect()

val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm")
val expected =
Seq(Seq(new Timestamp(dateFormat.parse("26/08/2015 18:00").getTime)),
Seq(new Timestamp(dateFormat.parse("27/10/2014 18:30").getTime)),
Seq(new Timestamp(dateFormat.parse("28/01/2016 20:00").getTime)))
assert(results.toSeq.map(_.toSeq) === expected)
}

test("load date types via custom date format") {
val customSchema = new StructType(Array(StructField("date", DateType, true)))
val options = Map(
"header" -> "true",
"inferSchema" -> "false",
"dateFormat" -> "dd/MM/yyyy hh:mm")
val results = sqlContext.read
.format("csv")
.options(options)
.schema(customSchema)
.load(testFile(datesFile))
.select("date")
.collect()

val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm")
val expected = Seq(
new Date(dateFormat.parse("26/08/2015 18:00").getTime),
new Date(dateFormat.parse("27/10/2014 18:30").getTime),
new Date(dateFormat.parse("28/01/2016 20:00").getTime))
val dates = results.toSeq.map(_.toSeq.head)
expected.zip(dates).foreach {
case (expectedDate, date) =>
// As it truncates the hours, minutes and etc., we only check
// if the dates (days, months and years) are the same via `toString()`.
assert(expectedDate.toString === date.toString)
}
}

test("setting comment to null disables comment support") {
val results = sqlContext.read
.format("csv")
Expand Down
Loading