Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.util.control.Exception.allCatch
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.util.DateFormatter
import org.apache.spark.sql.catalyst.util.TimestampFormatter
import org.apache.spark.sql.types._

Expand All @@ -32,6 +33,10 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
options.timeZone,
options.locale)

// Formatter built from the configured `dateFormat` and locale; used to probe fields for DateType.
private val dateParser = DateFormatter(options.dateFormat, options.locale)

// Decimal parser obtained from ExprUtils for the configured locale.
private val decimalParser = ExprUtils.getDecimalParser(options.locale)
Expand Down Expand Up @@ -102,6 +107,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
// DecimalTypes have different precisions and scales, so we try to find the common type.
compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType)
case DoubleType => tryParseDouble(field)
case DateType => tryParseDate(field)
case TimestampType => tryParseTimestamp(field)
case BooleanType => tryParseBoolean(field)
case StringType => StringType
Expand Down Expand Up @@ -152,6 +158,15 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
/**
 * Infers `DoubleType` when `field` converts via `toDouble` or when `isInfOrNan` accepts it;
 * otherwise hands the field on to `tryParseDate`, the next step in the inference chain.
 */
private def tryParseDouble(field: String): DataType = {
  allCatch.opt(field.toDouble) match {
    case Some(_) => DoubleType
    // Only consulted when the plain numeric conversion failed.
    case None if isInfOrNan(field) => DoubleType
    case None => tryParseDate(field)
  }
}

private def tryParseDate(field: String): DataType = {
// Infer DateType when the field parses with the configured `dateFormat`.
if ((allCatch opt dateParser.parse(field)).isDefined) {
DateType
} else {
tryParseTimestamp(field)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ import org.apache.spark.sql.types._
class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {

test("String fields types are inferred correctly from null types") {
val options = new CSVOptions(Map("timestampFormat" -> "yyyy-MM-dd HH:mm:ss"), false, "GMT")
val options = new CSVOptions(
Map("timestampFormat" -> "yyyy-MM-dd HH:mm:ss", "dateFormat" -> "yyyy-MM-dd"), false, "GMT")
val inferSchema = new CSVInferSchema(options)

assert(inferSchema.inferField(NullType, "") == NullType)
Expand All @@ -36,6 +37,8 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {
assert(inferSchema.inferField(NullType, "60") == IntegerType)
assert(inferSchema.inferField(NullType, "3.5") == DoubleType)
assert(inferSchema.inferField(NullType, "test") == StringType)
// [SPARK-25517] added assert statement for auto inferring DateType
assert(inferSchema.inferField(NullType, "2019-03-04") == DateType)
assert(inferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
assert(inferSchema.inferField(NullType, "True") == BooleanType)
assert(inferSchema.inferField(NullType, "FAlSE") == BooleanType)
Expand All @@ -47,14 +50,17 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {
}

test("String fields types are inferred correctly from other types") {
val options = new CSVOptions(Map("timestampFormat" -> "yyyy-MM-dd HH:mm:ss"), false, "GMT")
val options = new CSVOptions(
Map("timestampFormat" -> "yyyy-MM-dd HH:mm:ss", "dateFormat" -> "yyyy/MM/dd"), false, "GMT")
val inferSchema = new CSVInferSchema(options)

assert(inferSchema.inferField(LongType, "1.0") == DoubleType)
assert(inferSchema.inferField(LongType, "test") == StringType)
assert(inferSchema.inferField(IntegerType, "1.0") == DoubleType)
assert(inferSchema.inferField(DoubleType, null) == DoubleType)
assert(inferSchema.inferField(DoubleType, "test") == StringType)
// [SPARK-25517] added assert statement for auto inferring DateType
assert(inferSchema.inferField(NullType, "2019/03/04") == DateType)
assert(inferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType)
assert(inferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType)
assert(inferSchema.inferField(LongType, "True") == BooleanType)
Expand Down Expand Up @@ -123,6 +129,8 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {

assert(inferSchema.inferField(IntegerType, "\\N") == IntegerType)
assert(inferSchema.inferField(DoubleType, "\\N") == DoubleType)
// [SPARK-25517] added assert statement for auto inferring DateType
assert(inferSchema.inferField(DateType, "\\N") == DateType)
assert(inferSchema.inferField(TimestampType, "\\N") == TimestampType)
assert(inferSchema.inferField(BooleanType, "\\N") == BooleanType)
assert(inferSchema.inferField(DecimalType(1, 1), "\\N") == DecimalType(1, 1))
Expand Down