Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ private[csv] object CSVInferSchema {
case LongType => tryParseLong(field)
case DoubleType => tryParseDouble(field)
case TimestampType => tryParseTimestamp(field)
case BooleanType => tryParseBoolean(field)
case StringType => StringType
case other: DataType =>
throw new UnsupportedOperationException(s"Unexpected data type $other")
Expand Down Expand Up @@ -117,6 +118,14 @@ private[csv] object CSVInferSchema {
/**
 * Infers the type of `field` starting from the timestamp step of the inference chain.
 *
 * Returns `TimestampType` when `Timestamp.valueOf` accepts the text; otherwise
 * delegates to `tryParseBoolean` so the next candidate type is tried.
 */
def tryParseTimestamp(field: String): DataType = {
  // Any failure raised by Timestamp.valueOf is turned into None by allCatch.
  allCatch.opt(Timestamp.valueOf(field)) match {
    case Some(_) => TimestampType
    case None => tryParseBoolean(field)
  }
}

def tryParseBoolean(field: String): DataType = {
if ((allCatch opt field.toBoolean).isDefined) {
BooleanType
} else {
stringType()
}
Expand Down
5 changes: 5 additions & 0 deletions sql/core/src/test/resources/bool.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
bool
"True"
"False"

"true"
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType)
assert(CSVInferSchema.inferField(NullType, "test") == StringType)
assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
assert(CSVInferSchema.inferField(NullType, "True") == BooleanType)
assert(CSVInferSchema.inferField(NullType, "FAlSE") == BooleanType)
}

test("String fields types are inferred correctly from other types") {
Expand All @@ -40,6 +42,9 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DoubleType, "test") == StringType)
assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType)
assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType)
assert(CSVInferSchema.inferField(LongType, "True") == BooleanType)
assert(CSVInferSchema.inferField(IntegerType, "FALSE") == BooleanType)
assert(CSVInferSchema.inferField(TimestampType, "FALSE") == BooleanType)
}

test("Timestamp field types are inferred correctly from other types") {
Expand All @@ -48,6 +53,11 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
}

test("Boolean fields types are inferred correctly from other types") {
  // Near-miss spellings of boolean literals are expected to be inferred as StringType.
  val misspelledFalse = CSVInferSchema.inferField(LongType, "Fale")
  assert(misspelledFalse == StringType)

  val trailingCharTrue = CSVInferSchema.inferField(DoubleType, "TRUEe")
  assert(trailingCharTrue == StringType)
}

test("Type arrays are merged to highest common type") {
assert(
CSVInferSchema.mergeRowTypes(Array(StringType),
Expand All @@ -67,6 +77,7 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType)
}

test("Merging Nulltypes should yeild Nulltype.") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val emptyFile = "empty.csv"
private val commentsFile = "comments.csv"
private val disableCommentsFile = "disable_comments.csv"
private val boolFile = "bool.csv"
private val simpleSparseFile = "simple_sparse.csv"

private def testFile(fileName: String): String = {
Expand Down Expand Up @@ -112,6 +113,18 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(cars, withHeader = true, checkTypes = true)
}

test("test inferring booleans") {
  // Configure a CSV reader that treats the first row as a header and infers column types.
  val reader = sqlContext.read
    .format("csv")
    .option("header", "true")
    .option("inferSchema", "true")
  val result = reader.load(testFile(boolFile))

  // The single "bool" column is expected to be inferred as a nullable BooleanType.
  val boolColumn = StructField("bool", BooleanType, nullable = true)
  val expectedSchema = StructType(List(boolColumn))
  assert(result.schema === expectedSchema)
}

test("test with alternative delimiter and quote") {
val cars = sqlContext.read
.format("csv")
Expand Down