diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index a585cbed2551..b36ab3238703 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -90,6 +90,7 @@ private[csv] object CSVInferSchema {
       // DecimalTypes have different precisions and scales, so we try to find the common type.
       findTightestCommonType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType)
     case DoubleType => tryParseDouble(field, options)
+    case DateType => tryParseDate(field, options)
     case TimestampType => tryParseTimestamp(field, options)
     case BooleanType => tryParseBoolean(field, options)
     case StringType => StringType
@@ -140,14 +141,23 @@ private[csv] object CSVInferSchema {
   private def tryParseDouble(field: String, options: CSVOptions): DataType = {
     if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) {
       DoubleType
+    } else {
+      tryParseDate(field, options)
+    }
+  }
+
+  private def tryParseDate(field: String, options: CSVOptions): DataType = {
+    // This case applies when a custom `dateFormat` is set.
+    if ((allCatch opt options.dateFormatter.parse(field)).isDefined) {
+      DateType
     } else {
       tryParseTimestamp(field, options)
     }
   }
 
   private def tryParseTimestamp(field: String, options: CSVOptions): DataType = {
-    // This case infers a custom `dataFormat` is set.
-    if ((allCatch opt options.timestampFormat.parse(field)).isDefined) {
+    // This case applies when a custom `timestampFormat` is set.
+    if ((allCatch opt options.timestampFormatter.parse(field)).isDefined) {
       TimestampType
     } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
       // We keep this for backwards compatibility.
@@ -216,6 +226,9 @@ private[csv] object CSVInferSchema {
       } else {
         Some(DecimalType(range + scale, scale))
       }
+    // By design, `TimestampType` (8 bytes) is wider than `DateType` (4 bytes).
+    case (_: DateType, _: TimestampType) => Some(TimestampType)
+    case (_: TimestampType, _: DateType) => Some(TimestampType)
     case _ => None
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index c16790630ce1..83103a0b8a83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv
 
 import java.nio.charset.StandardCharsets
+import java.time.format.{DateTimeFormatter, ResolverStyle}
 import java.util.{Locale, TimeZone}
 
 import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling}
 import org.apache.commons.lang3.time.FastDateFormat
@@ -150,6 +151,16 @@ class CSVOptions(
 
   val isCommentSet = this.comment != '\u0000'
 
+  lazy val dateFormatter: DateTimeFormatter = {
+    DateTimeFormatter.ofPattern(dateFormat.getPattern)
+      .withLocale(Locale.US).withZone(timeZone.toZoneId).withResolverStyle(ResolverStyle.SMART)
+  }
+
+  lazy val timestampFormatter: DateTimeFormatter = {
+    DateTimeFormatter.ofPattern(timestampFormat.getPattern)
+      .withLocale(Locale.US).withZone(timeZone.toZoneId).withResolverStyle(ResolverStyle.SMART)
+  }
+
   def asWriterSettings: CsvWriterSettings = {
     val writerSettings = new CsvWriterSettings()
     val format = writerSettings.getFormat
diff --git a/sql/core/src/test/resources/test-data/dates-and-timestamps.csv b/sql/core/src/test/resources/test-data/dates-and-timestamps.csv
new file mode 100644
index 000000000000..0a9a4c2f8566
--- /dev/null
+++ b/sql/core/src/test/resources/test-data/dates-and-timestamps.csv
@@ -0,0 +1,4 @@
+timestamp,date
+26/08/2015 22:31:46.913,27/09/2015
+27/10/2014 22:33:31.601,26/12/2016
+28/01/2016 22:33:52.888,28/01/2017
\ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index 661742087112..d1a8822ca025 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -59,13 +59,21 @@ class CSVInferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne)
   }
 
-  test("Timestamp field types are inferred correctly via custom data format") {
+  test("Timestamp field types are inferred correctly via custom timestamp format") {
     var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), "GMT")
     assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType)
     options = new CSVOptions(Map("timestampFormat" -> "yyyy"), "GMT")
     assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType)
   }
 
+  test("Date field types are inferred correctly via custom date and timestamp formats") {
+    val options = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy",
+      "timestampFormat" -> "dd/MM/yyyy HH:mm:ss.SSS"), "GMT")
+    assert(CSVInferSchema.inferField(TimestampType,
+      "28/01/2017 22:31:46.913", options) == TimestampType)
+    assert(CSVInferSchema.inferField(DateType, "16/12/2012", options) == DateType)
+  }
+
   test("Timestamp field types are inferred correctly from other types") {
     val options = new CSVOptions(Map.empty[String, String], "GMT")
     assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 4398e547d921..b9e3c292f792 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -54,6 +54,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   private val simpleSparseFile = "test-data/simple_sparse.csv"
   private val numbersFile = "test-data/numbers.csv"
   private val datesFile = "test-data/dates.csv"
+  private val datesAndTimestampsFile = "test-data/dates-and-timestamps.csv"
   private val unescapedQuotesFile = "test-data/unescaped-quotes.csv"
   private val valueMalformedFile = "test-data/value-malformed.csv"
 
@@ -566,6 +567,44 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(results.toSeq.map(_.toSeq) === expected)
   }
 
+  test("inferring timestamp types and date types via custom formats") {
+    val options = Map(
+      "header" -> "true",
+      "inferSchema" -> "true",
+      "timestampFormat" -> "dd/MM/yyyy HH:mm:ss.SSS",
+      "dateFormat" -> "dd/MM/yyyy")
+    val results = spark.read
+      .format("csv")
+      .options(options)
+      .load(testFile(datesAndTimestampsFile))
+    assert(results.schema(0).dataType === TimestampType)
+    assert(results.schema(1).dataType === DateType)
+    val timestamps = spark.read
+      .format("csv")
+      .options(options)
+      .load(testFile(datesAndTimestampsFile))
+      .select("timestamp")
+      .collect()
+    val timestampFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm:ss.SSS", Locale.US)
+    val timestampExpected =
+      Seq(Seq(new Timestamp(timestampFormat.parse("26/08/2015 22:31:46.913").getTime)),
+        Seq(new Timestamp(timestampFormat.parse("27/10/2014 22:33:31.601").getTime)),
+        Seq(new Timestamp(timestampFormat.parse("28/01/2016 22:33:52.888").getTime)))
+    assert(timestamps.toSeq.map(_.toSeq) === timestampExpected)
+    val dates = spark.read
+      .format("csv")
+      .options(options)
+      .load(testFile(datesAndTimestampsFile))
+      .select("date")
+      .collect()
+    val dateFormat = new SimpleDateFormat("dd/MM/yyyy", Locale.US)
+    val dateExpected =
+      Seq(Seq(new Date(dateFormat.parse("27/09/2015").getTime)),
+        Seq(new Date(dateFormat.parse("26/12/2016").getTime)),
+        Seq(new Date(dateFormat.parse("28/01/2017").getTime)))
+    assert(dates.toSeq.map(_.toSeq) === dateExpected)
+  }
+
   test("load date types via custom date format") {
     val customSchema = new StructType(Array(StructField("date", DateType, true)))
     val options = Map(