Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources.csv

import java.math.BigDecimal
import java.text.{NumberFormat, SimpleDateFormat}
import java.text.NumberFormat
import java.util.Locale

import scala.util.control.Exception._
Expand Down Expand Up @@ -85,6 +85,7 @@ private[csv] object CSVInferSchema {
case NullType => tryParseInteger(field, options)
case IntegerType => tryParseInteger(field, options)
case LongType => tryParseLong(field, options)
case _: DecimalType => tryParseDecimal(field, options)
case DoubleType => tryParseDouble(field, options)
case TimestampType => tryParseTimestamp(field, options)
case BooleanType => tryParseBoolean(field, options)
Expand All @@ -107,10 +108,28 @@ private[csv] object CSVInferSchema {
if ((allCatch opt field.toLong).isDefined) {
LongType
} else {
tryParseDouble(field, options)
tryParseDecimal(field, options)
}
}

/**
 * Infers a `DecimalType` for `field` when it parses as a scale-free decimal
 * (e.g. an integer too wide for `Long`); otherwise falls back to double inference.
 */
private def tryParseDecimal(field: String, options: CSVOptions): DataType = {
  val parsed = allCatch opt {
    // `BigDecimal` construction throws when `field` is not numeric text.
    val value = new BigDecimal(field)
    if (value.scale > 0) {
      // Values with a positive scale (e.g. `1.1`) are not inferred as decimals,
      // because many other formats do not support decimal.
      tryParseDouble(field, options)
    } else {
      // `DecimalType` construction can still fail when the precision exceeds 38
      // or the scale exceeds the precision; `allCatch` covers that too.
      DecimalType(value.precision, value.scale)
    }
  }
  parsed.getOrElse(tryParseDouble(field, options))
}

private def tryParseDouble(field: String, options: CSVOptions): DataType = {
if ((allCatch opt field.toDouble).isDefined) {
DoubleType
Expand Down Expand Up @@ -170,6 +189,33 @@ private[csv] object CSVInferSchema {
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
Some(numericPrecedence(index))

// These two cases below deal with when `DecimalType` is larger than `IntegralType`.
case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
Some(t2)
case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
Some(t1)

// These two cases below deal with when `IntegralType` is larger than `DecimalType`.
case (t1: IntegralType, t2: DecimalType) =>
findTightestCommonType(DecimalType.forType(t1), t2)
case (t1: DecimalType, t2: IntegralType) =>
findTightestCommonType(t1, DecimalType.forType(t2))

// Double support larger range than fixed decimal, DecimalType.Maximum should be enough
// in most case, also have better precision.
case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
Some(DoubleType)

case (t1: DecimalType, t2: DecimalType) =>
val scale = math.max(t1.scale, t2.scale)
val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
if (range + scale > 38) {
// DecimalType can't support precision > 38
Some(DoubleType)
} else {
Some(DecimalType(range + scale, scale))
}

case _ => None
}
}
Expand Down
7 changes: 7 additions & 0 deletions sql/core/src/test/resources/decimal.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
~ decimal field has integer, integer and decimal values. The last value cannot fit to a long
~ long field has integer, long and integer values.
~ double field has double, double and decimal values.
decimal,long,double
1,1,0.1
1,9223372036854775807,1.0
92233720368547758070,1,92233720368547758070
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.execution.datasources.csv

import java.text.SimpleDateFormat

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._

Expand All @@ -35,6 +33,11 @@ class CSVInferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType)
assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType)
assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType)

val textValueOne = Long.MaxValue.toString + "0"
val decimalValueOne = new java.math.BigDecimal(textValueOne)
val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale)
assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne)
}

test("String fields types are inferred correctly from other types") {
Expand All @@ -49,6 +52,11 @@ class CSVInferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType)
assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType)
assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType)

val textValueOne = Long.MaxValue.toString + "0"
val decimalValueOne = new java.math.BigDecimal(textValueOne)
val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale)
assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne)
}

test("Timestamp field types are inferred correctly via custom data format") {
Expand Down Expand Up @@ -94,6 +102,7 @@ class CSVInferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType)
assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType)
assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1))
}

test("Merging Nulltypes should yield Nulltype.") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ import java.nio.charset.UnsupportedCharsetException
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.GzipCodec

Expand All @@ -45,6 +42,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val commentsFile = "comments.csv"
private val disableCommentsFile = "disable_comments.csv"
private val boolFile = "bool.csv"
private val decimalFile = "decimal.csv"
private val simpleSparseFile = "simple_sparse.csv"
private val numbersFile = "numbers.csv"
private val datesFile = "dates.csv"
Expand Down Expand Up @@ -135,6 +133,20 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(result.schema === expectedSchema)
}

test("test inferring decimals") {
  // Inference should yield: DecimalType for integral values too wide for Long,
  // LongType when every value fits in a Long, and DoubleType for fractional values.
  val expected = StructType(List(
    StructField("decimal", DecimalType(20, 0), nullable = true),
    StructField("long", LongType, nullable = true),
    StructField("double", DoubleType, nullable = true)))
  val df = sqlContext.read
    .format("csv")
    .option("header", "true")
    .option("comment", "~")
    .option("inferSchema", "true")
    .load(testFile(decimalFile))
  assert(df.schema === expected)
}

test("test with alternative delimiter and quote") {
val cars = sqlContext.read
.format("csv")
Expand Down