From c567dccc487d9f495eb00ca412d5b48d7899f401 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Nov 2018 15:31:06 +0300 Subject: [PATCH 01/16] Add a test --- .../expressions/CsvExpressionsSuite.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index d006197bd567..c64fd1532162 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import java.util.Calendar +import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException @@ -209,4 +210,17 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P "2015-12-31T16:00:00" ) } + + test("parse decimals using locale") { + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach { langTag => + val schema = new StructType().add("d", DecimalType(10, 5)) + val options = Map("locale" -> langTag, "sep" -> "|") + val expected = Decimal(1000.001, 10, 5) + val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) + val input = df.format(expected.toBigDecimal) + checkEvaluation( + CsvToStructs(schema, options, Literal.create(input), gmtId), + InternalRow(expected)) + } + } } From 2b41eba3c666efc2b6860fd31e387839afe378d5 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Nov 2018 15:31:24 +0300 Subject: [PATCH 02/16] Fix decimal parsing --- .../spark/sql/catalyst/csv/UnivocityParser.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 46ed58ed9283..271d1205cb00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.csv import java.io.InputStream import java.math.BigDecimal +import java.text.{DecimalFormat, DecimalFormatSymbols, NumberFormat} +import java.util.Locale import scala.util.Try import scala.util.control.NonFatal - import com.univocity.parsers.csv.CsvParser - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow @@ -104,6 +104,12 @@ class UnivocityParser( requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } + private val decimalParser = { + val df = new DecimalFormat("", new DecimalFormatSymbols(options.locale)) + df.setParseBigDecimal(true) + df + } + /** * Create a converter which converts the string value to a value according to a desired type. * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`). 
@@ -149,8 +155,8 @@ class UnivocityParser( case dt: DecimalType => (d: String) => nullSafeDatum(d, name, nullable, options) { datum => - val value = new BigDecimal(datum.replaceAll(",", "")) - Decimal(value, dt.precision, dt.scale) + val bigDecimal = decimalParser.parse(datum).asInstanceOf[BigDecimal] + Decimal(bigDecimal, dt.precision, dt.scale) } case _: TimestampType => (d: String) => From cf438ae2071f952eb72aab6dddf47c2aeb930d84 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Nov 2018 15:31:35 +0300 Subject: [PATCH 03/16] Add locale option --- .../org/apache/spark/sql/catalyst/csv/CSVOptions.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index cdaaa172e836..642823582a64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -131,13 +131,16 @@ class CSVOptions( val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) + // A language tag in IETF BCP 47 format + val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale) val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) From f9438c4b94d8a1dcf8b51d087cabae5a3c420db7 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Nov 2018 15:36:57 +0300 Subject: [PATCH 04/16] Updating the migration guide --- docs/sql-migration-guide-upgrade.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 50458e96f7c3..e0a229bd5b04 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -9,6 +9,8 @@ displayTitle: Spark SQL Upgrading Guide ## Upgrading From Spark SQL 2.4 to 3.0 + - Since Spark 3.0, to parse decimals in a locale-specific format from CSV, set the `locale` option to a proper value. + - In PySpark, when creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder was trying to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder, but the `SparkContext` is shared by all `SparkSession`s, so we should not update them. Since 3.0, the builder comes to not update the configurations. This is the same behavior as Java/Scala API in 2.3 and above. If you want to update them, you need to update them prior to creating a `SparkSession`. - In Spark version 2.4 and earlier, the parser of JSON data source treats empty strings as null for some data types such as `IntegerType`. For `FloatType` and `DoubleType`, it fails on empty strings and throws exceptions.
Since Spark 3.0, we disallow empty strings and will throw exceptions for data types except for `StringType` and `BinaryType`. From 3125c234182cf38b503f99e8d8095c2055186c20 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Nov 2018 16:02:21 +0300 Subject: [PATCH 05/16] Fix imports --- .../org/apache/spark/sql/catalyst/csv/UnivocityParser.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 271d1205cb00..843f60568985 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst.csv import java.io.InputStream import java.math.BigDecimal -import java.text.{DecimalFormat, DecimalFormatSymbols, NumberFormat} -import java.util.Locale +import java.text.{DecimalFormat, DecimalFormatSymbols} import scala.util.Try import scala.util.control.NonFatal + import com.univocity.parsers.csv.CsvParser + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow From 2f76352b5dec2318e722bc7d76e6530d25c41c70 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 11 Nov 2018 18:25:35 +0100 Subject: [PATCH 06/16] Renaming decimalParser to decimalFormat --- .../org/apache/spark/sql/catalyst/csv/UnivocityParser.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 843f60568985..8f8e0e7b1996 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -105,7 +105,7 @@ class UnivocityParser( requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } - private val decimalParser = { + private val decimalFormat = { val df = new DecimalFormat("", new DecimalFormatSymbols(options.locale)) df.setParseBigDecimal(true) df @@ -156,7 +156,7 @@ class UnivocityParser( case dt: DecimalType => (d: String) => nullSafeDatum(d, name, nullable, options) { datum => - val bigDecimal = decimalParser.parse(datum).asInstanceOf[BigDecimal] + val bigDecimal = decimalFormat.parse(datum).asInstanceOf[BigDecimal] Decimal(bigDecimal, dt.precision, dt.scale) } From 3dfce18280bad432a9faaa546d33e9d693a56f1e Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 11 Nov 2018 18:48:02 +0100 Subject: [PATCH 07/16] Moving the test to UnivocityParserSuite --- .../sql/catalyst/csv/UnivocityParserSuite.scala | 11 +++++++++++ .../catalyst/expressions/CsvExpressionsSuite.scala | 14 -------------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index e4e7dc2e8c0e..ceeaefaf78dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -196,4 +196,15 @@ class UnivocityParserSuite extends SparkFunSuite { 
assert(doubleVal2 == Double.PositiveInfinity) } + test("parse decimals using locale") { + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach { langTag => + val strVal = "1000.001" + val decimalVal = new BigDecimal(strVal) + val decimalType = new DecimalType(10, 5) + + val options = new CSVOptions(Map("locale" -> langTag), false, "GMT") + assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === + Decimal(decimalVal, decimalType.precision, decimalType.scale)) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index 6bba64b348f5..f5aaaec45615 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import java.text.{DecimalFormat, DecimalFormatSymbols} import java.text.SimpleDateFormat import java.util.{Calendar, Locale} @@ -227,17 +226,4 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P InternalRow(17836)) // number of days from 1970-01-01 } } - - test("parse decimals using locale") { - Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach { langTag => - val schema = new StructType().add("d", DecimalType(10, 5)) - val options = Map("locale" -> langTag, "sep" -> "|") - val expected = Decimal(1000.001, 10, 5) - val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) - val input = df.format(expected.toBigDecimal) - checkEvaluation( - CsvToStructs(schema, options, Literal.create(input), gmtId), - InternalRow(expected)) - } - } } From bdca7c482ac157cbe56c261dcba336f7b4ce0ceb Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 12 Nov 2018 17:45:32 +0100 Subject: [PATCH 08/16] Support the SQL config spark.sql.legacy.decimalParsing.enabled --- .../sql/catalyst/csv/UnivocityParser.scala | 10 ++++--- .../apache/spark/sql/internal/SQLConf.scala | 11 +++++++ .../catalyst/csv/UnivocityParserSuite.scala | 30 +++++++++++++++++-- 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 8f8e0e7b1996..cc3bf9480a90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -105,10 +106,12 @@ class UnivocityParser( requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } - private val decimalFormat = { + private val decimalParser = if (SQLConf.get.legacyDecimalParsing) { + (s: String) => new BigDecimal(s.replaceAll(",", "")) + } else { val df = new DecimalFormat("", new DecimalFormatSymbols(options.locale)) df.setParseBigDecimal(true) - df + (s: String) => df.parse(s).asInstanceOf[BigDecimal] } /** @@ -156,8 
+159,7 @@ class UnivocityParser( case dt: DecimalType => (d: String) => nullSafeDatum(d, name, nullable, options) { datum => - val bigDecimal = decimalFormat.parse(datum).asInstanceOf[BigDecimal] - Decimal(bigDecimal, dt.precision, dt.scale) + Decimal(decimalParser(datum), dt.precision, dt.scale) } case _: TimestampType => (d: String) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 518115dafd01..cd88aa6d610e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1594,6 +1594,14 @@ object SQLConf { "WHERE, which does not follow SQL standard.") .booleanConf .createWithDefault(false) + + val LEGACY_DECIMAL_PARSING_ENABLED = buildConf("spark.sql.legacy.decimalParsing.enabled") + .internal() + .doc("If it is set to false, it enables parsing decimals in locale specific formats. " + + "To switch back to previous behaviour when parsing was performed by java.math.BigDecimal and " + + "all commas were removed from the input, set the flag to true.") + .booleanConf + .createWithDefault(false) } /** @@ -2009,6 +2017,9 @@ class SQLConf extends Serializable with Logging { def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG) + def legacyDecimalParsing: Boolean = getConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED) + + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index ceeaefaf78dc..cde11b39c44c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal +import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.util.Locale import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class UnivocityParserSuite extends SparkFunSuite { +class UnivocityParserSuite extends SparkFunSuite with SQLHelper { private val parser = new UnivocityParser( StructType(Seq.empty), new CSVOptions(Map.empty[String, String], false, "GMT")) @@ -197,14 +201,34 @@ class UnivocityParserSuite extends SparkFunSuite { } test("parse decimals using locale") { - Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach { langTag => + def checkDecimalParsing(langTag: String): Unit = { val strVal = "1000.001" val decimalVal = new BigDecimal(strVal) val decimalType = new DecimalType(10, 5) + val expected = Decimal(decimalVal, decimalType.precision, decimalType.scale) + val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) + val input = df.format(expected.toBigDecimal) val options = new CSVOptions(Map("locale" -> langTag), false, "GMT") - assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === + val parser = new UnivocityParser(new StructType().add("d", decimalType), options) + assert(parser.makeConverter("_1", 
decimalType, options = options).apply(input) === Decimal(decimalVal, decimalType.precision, decimalType.scale)) } + + withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "false") { + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing) + } + + withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "true") { + Seq("en-US", "ko-KR").foreach(checkDecimalParsing) + } + + withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "true") { + Seq("ru-RU").foreach { langTag => + intercept[NumberFormatException] { + checkDecimalParsing(langTag) + } + } + } } } From 8c5593ef13c3f4da5d2c746e906e8a33e46ee2a8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 12 Nov 2018 18:05:44 +0100 Subject: [PATCH 09/16] Updating the migration guide. --- docs/sql-migration-guide-upgrade.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index e0a229bd5b04..2559423eea78 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -9,7 +9,7 @@ displayTitle: Spark SQL Upgrading Guide ## Upgrading From Spark SQL 2.4 to 3.0 - - Since Spark 3.0, to parse decimals in a locale-specific format from CSV, set the `locale` option to a proper value. + - In Spark version 2.4 and earlier, the accepted format of decimals parsed from CSV was an optional sign ('+' or '-'), followed by a sequence of zero or more decimal digits, optionally followed by a fraction, optionally followed by an exponent. Any commas were removed from the input before parsing. Since Spark 3.0, the format varies and depends on the locale, which can be set via the CSV option `locale`. The default locale is `en-US`. To switch back to the previous behavior, set `spark.sql.legacy.decimalParsing.enabled` to `true`. - In PySpark, when creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder was trying to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder, but the `SparkContext` is shared by all `SparkSession`s, so we should not update them. Since 3.0, the builder comes to not update the configurations. This is the same behavior as Java/Scala API in 2.3 and above. If you want to update them, you need to update them prior to creating a `SparkSession`. From 18470b0600d2c67e1f84a52030f73737d67603d2 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 12 Nov 2018 18:10:16 +0100 Subject: [PATCH 10/16] Refactoring --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cd88aa6d610e..5b0d460794b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1598,8 +1598,8 @@ object SQLConf { val LEGACY_DECIMAL_PARSING_ENABLED = buildConf("spark.sql.legacy.decimalParsing.enabled") .internal() .doc("If it is set to false, it enables parsing decimals in locale specific formats.
" + - "To switch back to previous behaviour when parsing was performed by java.math.BigDecimal and " + - "all commas were removed from the input, set the flag to true.") + "To switch back to previous behaviour when parsing was performed by java.math.BigDecimal " + + "and all commas were removed from the input, set the flag to true.") .booleanConf .createWithDefault(false) } @@ -2019,7 +2019,6 @@ class SQLConf extends Serializable with Logging { def legacyDecimalParsing: Boolean = getConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED) - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ From c28b79f3b4c485968df61281079114c1d6756348 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 12 Nov 2018 18:10:46 +0100 Subject: [PATCH 11/16] Removing internal --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5b0d460794b7..3b39a68399f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1596,7 +1596,6 @@ object SQLConf { .createWithDefault(false) val LEGACY_DECIMAL_PARSING_ENABLED = buildConf("spark.sql.legacy.decimalParsing.enabled") - .internal() .doc("If it is set to false, it enables parsing decimals in locale specific formats. " + "To switch back to previous behaviour when parsing was performed by java.math.BigDecimal " + "and all commas were removed from the input, set the flag to true.") From 1723da2293e098279d8070fa270dd3785c45af1c Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 12 Nov 2018 18:15:36 +0100 Subject: [PATCH 12/16] Test refactoring --- .../spark/sql/catalyst/csv/UnivocityParserSuite.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index cde11b39c44c..6fc1f4e40dc2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -202,8 +202,7 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { test("parse decimals using locale") { def checkDecimalParsing(langTag: String): Unit = { - val strVal = "1000.001" - val decimalVal = new BigDecimal(strVal) + val decimalVal = new BigDecimal("1000.001") val decimalType = new DecimalType(10, 5) val expected = Decimal(decimalVal, decimalType.precision, decimalType.scale) val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) @@ -211,8 +210,8 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { val options = new CSVOptions(Map("locale" -> langTag), false, "GMT") val parser = new UnivocityParser(new StructType().add("d", decimalType), options) - assert(parser.makeConverter("_1", decimalType, options = options).apply(input) === - Decimal(decimalVal, decimalType.precision, decimalType.scale)) + + assert(parser.makeConverter("_1", decimalType, options = options).apply(input) === expected) } withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "false") { From 6cdafa5f9c0c79e1c217c61cad242f8aebee2f27 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: 
Tue, 13 Nov 2018 21:52:14 +0100 Subject: [PATCH 13/16] Added a test for inferring the decimal type --- .../catalyst/expressions/CsvExpressionsSuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index f5aaaec45615..6d4728c0d991 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.text.{DecimalFormat, DecimalFormatSymbols} import java.text.SimpleDateFormat import java.util.{Calendar, Locale} @@ -226,4 +227,15 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P InternalRow(17836)) // number of days from 1970-01-01 } } + + test("inferring the decimal type using locale") { + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach { langTag => + val options = Map("locale" -> langTag, "sep" -> "|") + val expected = Decimal(1000.001, 10, 5) + val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) + val input = df.format(expected.toBigDecimal) + + checkEvaluation(SchemaOfCsv(Literal.create(input), options), "struct<_c0:decimal(10, 5)>") + } + } } From 14b5109bfba9111a9d4ecb463713a1526b1e3fb1 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 14 Nov 2018 09:37:47 +0100 Subject: [PATCH 14/16] Inferring decimals from CSV --- .../spark/sql/catalyst/csv/CSVExprUtils.scala | 22 +++ .../sql/catalyst/csv/CSVInferSchema.scala | 72 ++++---- .../sql/catalyst/csv/UnivocityParser.scala | 10 +- .../catalyst/expressions/csvExpressions.scala | 5 +- .../catalyst/csv/CSVInferSchemaSuite.scala | 157 ++++++++++++------ .../expressions/CsvExpressionsSuite.scala | 12 -- .../datasources/csv/CSVDataSource.scala | 4 +- 7 files changed, 174 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala index bbe27831f01d..5c211c0d4ec3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.catalyst.csv +import java.math.BigDecimal +import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition} +import java.util.Locale + object CSVExprUtils { /** * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`). 
@@ -79,4 +83,22 @@ object CSVExprUtils { throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") } } + + def getDecimalParser(useLegacyParser: Boolean, locale: Locale): String => java.math.BigDecimal = { + if (useLegacyParser) { + (s: String) => new BigDecimal(s.replaceAll(",", "")) + } else { + val df = new DecimalFormat("", new DecimalFormatSymbols(locale)) + df.setParseBigDecimal(true) + (s: String) => { + val pos = new ParsePosition(0) + val result = df.parse(s, pos).asInstanceOf[BigDecimal] + if (pos.getIndex() != s.length() || pos.getErrorIndex() != -1) { + throw new IllegalArgumentException("Cannot parse any decimal"); + } else { + result + } + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 799e9994451b..c102b5e35226 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -17,16 +17,21 @@ package org.apache.spark.sql.catalyst.csv -import java.math.BigDecimal +import java.text.ParsePosition import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -object CSVInferSchema { +class CSVInferSchema(options: CSVOptions) extends Serializable { + + private val decimalParser = { + CSVExprUtils.getDecimalParser(SQLConf.get.legacyDecimalParsing, options.locale) + } /** * Similar to the JSON schema inference @@ -36,14 +41,13 @@ object CSVInferSchema { */ def infer( tokenRDD: RDD[Array[String]], - header: Array[String], - options: CSVOptions): StructType = { + header: Array[String]): StructType = { val fields = if (options.inferSchemaFlag) { val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) val rootTypes: Array[DataType] = - tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) + tokenRDD.aggregate(startType)(inferRowType, mergeRowTypes) - toStructFields(rootTypes, header, options) + toStructFields(rootTypes, header) } else { // By default fields are assumed to be StringType header.map(fieldName => StructField(fieldName, StringType, nullable = true)) @@ -54,8 +58,7 @@ object CSVInferSchema { def toStructFields( fieldTypes: Array[DataType], - header: Array[String], - options: CSVOptions): Array[StructField] = { + header: Array[String]): Array[StructField] = { header.zip(fieldTypes).map { case (thisHeader, rootType) => val dType = rootType match { case _: NullType => StringType @@ -65,11 +68,10 @@ object CSVInferSchema { } } - def inferRowType(options: CSVOptions) - (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { + def inferRowType(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { var i = 0 while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. - rowSoFar(i) = inferField(rowSoFar(i), next(i), options) + rowSoFar(i) = inferField(rowSoFar(i), next(i)) i+=1 } rowSoFar @@ -85,20 +87,20 @@ object CSVInferSchema { * Infer type of string field. Given known type Double, and a string "1", there is no * point checking if it is an Int, as the final type must be Double or higher. 
*/ - def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = { + def inferField(typeSoFar: DataType, field: String): DataType = { if (field == null || field.isEmpty || field == options.nullValue) { typeSoFar } else { typeSoFar match { - case NullType => tryParseInteger(field, options) - case IntegerType => tryParseInteger(field, options) - case LongType => tryParseLong(field, options) + case NullType => tryParseInteger(field) + case IntegerType => tryParseInteger(field) + case LongType => tryParseLong(field) case _: DecimalType => // DecimalTypes have different precisions and scales, so we try to find the common type. - compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType) - case DoubleType => tryParseDouble(field, options) - case TimestampType => tryParseTimestamp(field, options) - case BooleanType => tryParseBoolean(field, options) + compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType) + case DoubleType => tryParseDouble(field) + case TimestampType => tryParseTimestamp(field) + case BooleanType => tryParseBoolean(field) case StringType => StringType case other: DataType => throw new UnsupportedOperationException(s"Unexpected data type $other") @@ -106,30 +108,30 @@ object CSVInferSchema { } } - private def isInfOrNan(field: String, options: CSVOptions): Boolean = { + private def isInfOrNan(field: String): Boolean = { field == options.nanValue || field == options.negativeInf || field == options.positiveInf } - private def tryParseInteger(field: String, options: CSVOptions): DataType = { + private def tryParseInteger(field: String): DataType = { if ((allCatch opt field.toInt).isDefined) { IntegerType } else { - tryParseLong(field, options) + tryParseLong(field) } } - private def tryParseLong(field: String, options: CSVOptions): DataType = { + private def tryParseLong(field: String): DataType = { if ((allCatch opt field.toLong).isDefined) { LongType } else { - tryParseDecimal(field, options) + tryParseDecimal(field) } } - private def tryParseDecimal(field: String, options: CSVOptions): DataType = { + private def tryParseDecimal(field: String): DataType = { val decimalTry = allCatch opt { - // `BigDecimal` conversion can fail when the `field` is not a form of number. - val bigDecimal = new BigDecimal(field) + // The conversion can fail when the `field` is not a form of number. + val bigDecimal = decimalParser(field) // Because many other formats do not support decimal, it reduces the cases for // decimals by disallowing values having scale (eg. `1.1`). if (bigDecimal.scale <= 0) { @@ -138,21 +140,21 @@ object CSVInferSchema { // 2. scale is bigger than precision. DecimalType(bigDecimal.precision, bigDecimal.scale) } else { - tryParseDouble(field, options) + tryParseDouble(field) } } - decimalTry.getOrElse(tryParseDouble(field, options)) + decimalTry.getOrElse(tryParseDouble(field)) } - private def tryParseDouble(field: String, options: CSVOptions): DataType = { - if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) { + private def tryParseDouble(field: String): DataType = { + if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field)) { DoubleType } else { - tryParseTimestamp(field, options) + tryParseTimestamp(field) } } - private def tryParseTimestamp(field: String, options: CSVOptions): DataType = { + private def tryParseTimestamp(field: String): DataType = { // This case infers a custom `dataFormat` is set. 
if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { TimestampType @@ -160,11 +162,11 @@ object CSVInferSchema { // We keep this for backwards compatibility. TimestampType } else { - tryParseBoolean(field, options) + tryParseBoolean(field) } } - private def tryParseBoolean(field: String, options: CSVOptions): DataType = { + private def tryParseBoolean(field: String): DataType = { if ((allCatch opt field.toBoolean).isDefined) { BooleanType } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index cc3bf9480a90..bfc13554ebd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.catalyst.csv import java.io.InputStream -import java.math.BigDecimal -import java.text.{DecimalFormat, DecimalFormatSymbols} import scala.util.Try import scala.util.control.NonFatal @@ -106,12 +104,8 @@ class UnivocityParser( requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } - private val decimalParser = if (SQLConf.get.legacyDecimalParsing) { - (s: String) => new BigDecimal(s.replaceAll(",", "")) - } else { - val df = new DecimalFormat("", new DecimalFormatSymbols(options.locale)) - df.setParseBigDecimal(true) - (s: String) => df.parse(s).asInstanceOf[BigDecimal] + private val decimalParser = { + CSVExprUtils.getDecimalParser(SQLConf.get.legacyDecimalParsing, options.locale) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index aff372b899f8..5da9c347389b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -176,8 +176,9 @@ case class SchemaOfCsv( val header = row.zipWithIndex.map { case (_, index) => s"_c$index" } val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) - val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row) - val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions)) + val inferSchema = new CSVInferSchema(parsedOptions) + val fieldTypes = inferSchema.inferRowType(startType, row) + val st = StructType(inferSchema.toStructFields(fieldTypes, header)) UTF8String.fromString(st.catalogString) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index 651846d2ebcb..affd4c43e9bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -17,126 +17,185 @@ package org.apache.spark.sql.catalyst.csv +import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.util.Locale + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -class CSVInferSchemaSuite extends SparkFunSuite { +class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { test("String fields types are 
inferred correctly from null types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(NullType, "", options) == NullType) - assert(CSVInferSchema.inferField(NullType, null, options) == NullType) - assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType) - assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType) - assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "test", options) == StringType) - assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType) - assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(NullType, "") == NullType) + assert(inferSchema.inferField(NullType, null) == NullType) + assert(inferSchema.inferField(NullType, "100000000000") == LongType) + assert(inferSchema.inferField(NullType, "60") == IntegerType) + assert(inferSchema.inferField(NullType, "3.5") == DoubleType) + assert(inferSchema.inferField(NullType, "test") == StringType) + assert(inferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType) + assert(inferSchema.inferField(NullType, "True") == BooleanType) + assert(inferSchema.inferField(NullType, "FAlSE") == BooleanType) val textValueOne = Long.MaxValue.toString + "0" val decimalValueOne = new java.math.BigDecimal(textValueOne) val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) - assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne) + assert(inferSchema.inferField(NullType, textValueOne) == expectedTypeOne) } test("String fields types are inferred correctly from other types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType) - assert(CSVInferSchema.inferField(LongType, "test", options) == StringType) - assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType) - assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType) - assert(CSVInferSchema.inferField(DoubleType, "test", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType) - assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType) - assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(LongType, "1.0") == DoubleType) + assert(inferSchema.inferField(LongType, "test") == StringType) + assert(inferSchema.inferField(IntegerType, "1.0") == DoubleType) + assert(inferSchema.inferField(DoubleType, null) == DoubleType) + assert(inferSchema.inferField(DoubleType, "test") == StringType) + assert(inferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType) + assert(inferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType) + assert(inferSchema.inferField(LongType, "True") == BooleanType) + assert(inferSchema.inferField(IntegerType, "FALSE") == BooleanType) + assert(inferSchema.inferField(TimestampType, "FALSE") == BooleanType) val 
textValueOne = Long.MaxValue.toString + "0" val decimalValueOne = new java.math.BigDecimal(textValueOne) val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) - assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne) + assert(inferSchema.inferField(IntegerType, textValueOne) == expectedTypeOne) } test("Timestamp field types are inferred correctly via custom data format") { var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) + var inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(TimestampType, "2015-08") == TimestampType) + options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) + inferSchema = new CSVInferSchema(options) + assert(inferSchema.inferField(TimestampType, "2015") == TimestampType) } test("Timestamp field types are inferred correctly from other types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType) - assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(IntegerType, "2015-08-20 14") == StringType) + assert(inferSchema.inferField(DoubleType, "2015-08-20 14:10") == StringType) + assert(inferSchema.inferField(LongType, "2015-08 14:49:00") == StringType) } test("Boolean fields types are inferred correctly from other types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType) - assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(LongType, "Fale") == StringType) + assert(inferSchema.inferField(DoubleType, "TRUEe") == StringType) } test("Type arrays are merged to highest common type") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val inferSchema = new CSVInferSchema(options) + assert( - CSVInferSchema.mergeRowTypes(Array(StringType), + inferSchema.mergeRowTypes(Array(StringType), Array(DoubleType)).deep == Array(StringType).deep) assert( - CSVInferSchema.mergeRowTypes(Array(IntegerType), + inferSchema.mergeRowTypes(Array(IntegerType), Array(LongType)).deep == Array(LongType).deep) assert( - CSVInferSchema.mergeRowTypes(Array(DoubleType), + inferSchema.mergeRowTypes(Array(DoubleType), Array(LongType)).deep == Array(DoubleType).deep) } test("Null fields are handled properly when a nullValue is specified") { var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") - assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) - assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) + var inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(NullType, "null") == NullType) + assert(inferSchema.inferField(StringType, "null") == StringType) + assert(inferSchema.inferField(LongType, "null") == LongType) options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT") - 
assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) - assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) - assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) - assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType) - assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1)) + inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(IntegerType, "\\N") == IntegerType) + assert(inferSchema.inferField(DoubleType, "\\N") == DoubleType) + assert(inferSchema.inferField(TimestampType, "\\N") == TimestampType) + assert(inferSchema.inferField(BooleanType, "\\N") == BooleanType) + assert(inferSchema.inferField(DecimalType(1, 1), "\\N") == DecimalType(1, 1)) } test("Merging Nulltypes should yield Nulltype.") { - val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val inferSchema = new CSVInferSchema(options) + + val mergedNullTypes = inferSchema.mergeRowTypes(Array(NullType), Array(NullType)) assert(mergedNullTypes.deep == Array(NullType).deep) } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(TimestampType, "2015-08") == TimestampType) } test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val inferSchema = new CSVInferSchema(options) // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). - assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == + assert(inferSchema.inferField(DecimalType(3, -10), "1.19E11") == DecimalType(4, -9)) // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 and scale 20. 
val value = "12345678901234567890.01234567890123456789" - assert(CSVInferSchema.inferField(DecimalType(3, -10), value, options) == DoubleType) + assert(inferSchema.inferField(DecimalType(3, -10), value) == DoubleType) // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType - assert(CSVInferSchema.inferField(NullType, s"${Long.MaxValue}1", options) == DecimalType(20, 0)) - assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00", options) + assert(inferSchema.inferField(NullType, s"${Long.MaxValue}1") == DecimalType(20, 0)) + assert(inferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00") == StringType) } test("DoubleType should be inferred when user defined nan/inf are provided") { val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", "positiveInf" -> "inf"), false, "GMT") - assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(NullType, "nan") == DoubleType) + assert(inferSchema.inferField(NullType, "inf") == DoubleType) + assert(inferSchema.inferField(NullType, "-inf") == DoubleType) + } + + test("inferring the decimal type using locale") { + def checkDecimalInfer(langTag: String, expectedType: DataType): Unit = { + val options = new CSVOptions( + parameters = Map("locale" -> langTag, "inferSchema" -> "true", "sep" -> "|"), + columnPruning = false, + defaultTimeZoneId = "GMT") + val inferSchema = new CSVInferSchema(options) + + val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) + val input = df.format(Decimal(1000001).toBigDecimal) + + assert(inferSchema.inferField(NullType, input) == expectedType) + } + + withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "false") { + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0))) + } + + withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "true") { + Seq("en-US", "ko-KR").foreach(checkDecimalInfer(_, DecimalType(7, 0))) + } + + withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "true") { + Seq("ru-RU", "de-DE").foreach(checkDecimalInfer(_, StringType)) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index 6d4728c0d991..f5aaaec45615 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import java.text.{DecimalFormat, DecimalFormatSymbols} import java.text.SimpleDateFormat import java.util.{Calendar, Locale} @@ -227,15 +226,4 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P InternalRow(17836)) // number of days from 1970-01-01 } } - - test("inferring the decimal type using locale") { - Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach { langTag => - val options = Map("locale" -> langTag, "sep" -> "|") - val expected = Decimal(1000.001, 10, 5) - val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) - val input = df.format(expected.toBigDecimal) - - 
checkEvaluation(SchemaOfCsv(Literal.create(input), options), "struct<_c0:decimal(10, 5)>") - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4808e8ef042d..615c6b36e121 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -135,7 +135,7 @@ object TextInputCSVDataSource extends CSVDataSource { val parser = new CsvParser(parsedOptions.asParserSettings) linesWithoutHeader.map(parser.parseLine) } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + new CSVInferSchema(parsedOptions).infer(tokenRDD, header) case _ => // If the first line could not be read, just return the empty schema. StructType(Nil) @@ -206,7 +206,7 @@ object MultiLineCSVDataSource extends CSVDataSource { new CsvParser(parsedOptions.asParserSettings)) } val sampled = CSVUtils.sample(tokenRDD, parsedOptions) - CSVInferSchema.infer(sampled, header, parsedOptions) + new CSVInferSchema(parsedOptions).infer(sampled, header) case None => // If the first row could not be read, just return the empty schema. StructType(Nil) From bab8fb26060860e1364684b92e40f996c82c674c Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 22 Nov 2018 08:05:56 +0100 Subject: [PATCH 15/16] Renaming df to decimalFormat --- .../org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala index 5c211c0d4ec3..0d8a54922ce2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala @@ -88,11 +88,11 @@ object CSVExprUtils { if (useLegacyParser) { (s: String) => new BigDecimal(s.replaceAll(",", "")) } else { - val df = new DecimalFormat("", new DecimalFormatSymbols(locale)) - df.setParseBigDecimal(true) + val decimalFormat = new DecimalFormat("", new DecimalFormatSymbols(locale)) + decimalFormat.setParseBigDecimal(true) (s: String) => { val pos = new ParsePosition(0) - val result = df.parse(s, pos).asInstanceOf[BigDecimal] + val result = decimalFormat.parse(s, pos).asInstanceOf[BigDecimal] if (pos.getIndex() != s.length() || pos.getErrorIndex() != -1) { throw new IllegalArgumentException("Cannot parse any decimal"); } else { From 0859624bffda3f08b7691d81830dfc5bdd3cc73e Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 28 Nov 2018 22:24:46 +0100 Subject: [PATCH 16/16] Removing SQL config and special handling of Locale.US --- docs/sql-migration-guide-upgrade.md | 2 -- .../spark/sql/catalyst/csv/CSVExprUtils.scala | 18 ---------------- .../sql/catalyst/csv/CSVInferSchema.scala | 6 ++---- .../sql/catalyst/csv/UnivocityParser.scala | 7 ++----- .../sql/catalyst/expressions/ExprUtils.scala | 21 +++++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 9 -------- .../catalyst/csv/CSVInferSchemaSuite.scala | 12 +---------- .../catalyst/csv/UnivocityParserSuite.scala | 16 +------------- 8 files changed, 27 insertions(+), 64 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 6a4f28e5420f..55838e773e4b 100644 --- a/docs/sql-migration-guide-upgrade.md +++ 
b/docs/sql-migration-guide-upgrade.md @@ -11,8 +11,6 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 3.0, the Dataset and DataFrame API `unionAll` is not deprecated any more. It is an alias for `union`. - - In Spark version 2.4 and earlier, the accepted format of decimals parsed from CSV was an optional sign ('+' or '-'), followed by a sequence of zero or more decimal digits, optionally followed by a fraction, optionally followed by an exponent. Any commas were removed from the input before parsing. Since Spark 3.0, the format varies and depends on the locale, which can be set via the CSV option `locale`. The default locale is `en-US`. To switch back to the previous behavior, set `spark.sql.legacy.decimalParsing.enabled` to `true`. - - In PySpark, when creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder was trying to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder, but the `SparkContext` is shared by all `SparkSession`s, so we should not update them. Since 3.0, the builder comes to not update the configurations. This is the same behavior as Java/Scala API in 2.3 and above. If you want to update them, you need to update them prior to creating a `SparkSession`. - In Spark version 2.4 and earlier, the parser of JSON data source treats empty strings as null for some data types such as `IntegerType`. For `FloatType` and `DoubleType`, it fails on empty strings and throws exceptions. Since Spark 3.0, we disallow empty strings and will throw exceptions for data types except for `StringType` and `BinaryType`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala index 0d8a54922ce2..6c982a1de9a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala @@ -83,22 +83,4 @@ object CSVExprUtils { throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") } } - - def getDecimalParser(useLegacyParser: Boolean, locale: Locale): String => java.math.BigDecimal = { - if (useLegacyParser) { - (s: String) => new BigDecimal(s.replaceAll(",", "")) - } else { - val decimalFormat = new DecimalFormat("", new DecimalFormatSymbols(locale)) - decimalFormat.setParseBigDecimal(true) - (s: String) => { - val pos = new ParsePosition(0) - val result = decimalFormat.parse(s, pos).asInstanceOf[BigDecimal] - if (pos.getIndex() != s.length() || pos.getErrorIndex() != -1) { - throw new IllegalArgumentException("Cannot parse any decimal"); - } else { - result - } - } - } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index c102b5e35226..94cb4b114e6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.catalyst.csv -import java.text.ParsePosition - import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._ class CSVInferSchema(options: CSVOptions) extends Serializable { private val decimalParser = { - CSVExprUtils.getDecimalParser(SQLConf.get.legacyDecimalParsing, options.locale) + ExprUtils.getDecimalParser(options.locale) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 41d0dc60eb48..85e129224c91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -26,9 +26,8 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow} import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -104,9 +103,7 @@ class UnivocityParser( requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } - private val decimalParser = { - CSVExprUtils.getDecimalParser(SQLConf.get.legacyDecimalParsing, options.locale) - } + private val decimalParser = ExprUtils.getDecimalParser(options.locale) /** * Create a converter which converts the string value to a value according to a desired type. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 89e9071324ef..3f3d6b2b63a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.expressions +import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition} +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType} @@ -83,4 +86,22 @@ object ExprUtils { } } } + + def getDecimalParser(locale: Locale): String => java.math.BigDecimal = { + if (locale == Locale.US) { // Special handling the default locale for backward compatibility + (s: String) => new java.math.BigDecimal(s.replaceAll(",", "")) + } else { + val decimalFormat = new DecimalFormat("", new DecimalFormatSymbols(locale)) + decimalFormat.setParseBigDecimal(true) + (s: String) => { + val pos = new ParsePosition(0) + val result = decimalFormat.parse(s, pos).asInstanceOf[java.math.BigDecimal] + if (pos.getIndex() != s.length() || pos.getErrorIndex() != -1) { + throw new IllegalArgumentException("Cannot parse any decimal"); + } else { + result + } + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d789fec9d575..7bcf21595ce5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1610,13 +1610,6 @@ object SQLConf { """ "... 
N more fields" placeholder.""") .intConf .createWithDefault(25) - - val LEGACY_DECIMAL_PARSING_ENABLED = buildConf("spark.sql.legacy.decimalParsing.enabled") - .doc("If it is set to false, it enables parsing decimals in locale specific formats. " + - "To switch back to previous behaviour when parsing was performed by java.math.BigDecimal " + - "and all commas were removed from the input, set the flag to true.") - .booleanConf - .createWithDefault(false) } /** @@ -2037,8 +2030,6 @@ class SQLConf extends Serializable with Logging { def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS) - def legacyDecimalParsing: Boolean = getConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED) - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index affd4c43e9bf..1a020e67a75b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -186,16 +186,6 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { assert(inferSchema.inferField(NullType, input) == expectedType) } - withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "false") { - Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0))) - } - - withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "true") { - Seq("en-US", "ko-KR").foreach(checkDecimalInfer(_, DecimalType(7, 0))) - } - - withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "true") { - Seq("ru-RU", "de-DE").foreach(checkDecimalInfer(_, StringType)) - } + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 6fc1f4e40dc2..7212402ef5cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -214,20 +214,6 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { assert(parser.makeConverter("_1", decimalType, options = options).apply(input) === expected) } - withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "false") { - Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing) - } - - withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "true") { - Seq("en-US", "ko-KR").foreach(checkDecimalParsing) - } - - withSQLConf(SQLConf.LEGACY_DECIMAL_PARSING_ENABLED.key -> "true") { - Seq("ru-RU").foreach { langTag => - intercept[NumberFormatException] { - checkDecimalParsing(langTag) - } - } - } + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing) } }
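
Taken together, the series lands on one decimal-parsing helper, `ExprUtils.getDecimalParser`: `Locale.US` keeps the legacy comma-stripping path for backward compatibility, and every other locale goes through `java.text.DecimalFormat` with a strict full-match check. The sketch below restates that final logic as a standalone runnable program; the object name, `main` method, and error message are added here for illustration and are not part of the patch.

import java.math.BigDecimal
import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
import java.util.Locale

object LocaleDecimalParserSketch {
  // Mirrors the shape of ExprUtils.getDecimalParser from PATCH 16/16.
  def getDecimalParser(locale: Locale): String => BigDecimal = {
    if (locale == Locale.US) {
      // Legacy behavior: strip commas, then delegate to java.math.BigDecimal.
      (s: String) => new BigDecimal(s.replaceAll(",", ""))
    } else {
      val decimalFormat = new DecimalFormat("", new DecimalFormatSymbols(locale))
      decimalFormat.setParseBigDecimal(true)
      (s: String) => {
        val pos = new ParsePosition(0)
        val result = decimalFormat.parse(s, pos).asInstanceOf[BigDecimal]
        // DecimalFormat.parse is lenient and stops at the first character it
        // cannot consume, so only a fully consumed input counts as a decimal.
        if (pos.getIndex != s.length || pos.getErrorIndex != -1) {
          throw new IllegalArgumentException(s"Cannot parse '$s' as a decimal")
        }
        result
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Round-trip a value the way the test suites in this series do: format it
    // with the target locale first, then parse it back.
    val de = Locale.forLanguageTag("de-DE")
    val format = new DecimalFormat("", new DecimalFormatSymbols(de))
    val input = format.format(new BigDecimal("1000.001")) // e.g. "1000,001"
    println(getDecimalParser(de)(input))              // 1000.001
    println(getDecimalParser(Locale.US)("1,000.001")) // 1000.001
  }
}

Formatting first and then parsing, as the suites in patches 08 and 14 do, keeps the example independent of how a particular JDK renders grouping separators.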
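
The `ParsePosition` check introduced in patch 14 is what makes the shared parser safe for schema inference. `DecimalFormat.parse` accepts a numeric prefix and silently ignores the rest of the string, so without the full-consumption test, `tryParseDecimal` in `CSVInferSchema` could classify timestamp-like fields as decimals. A small sketch of that failure mode (the object name and sample value are illustrative):

import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
import java.util.Locale

object LenientParseSketch {
  def main(args: Array[String]): Unit = {
    val format = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag("ru-RU")))
    format.setParseBigDecimal(true)

    val field = "2015-08-20 14:57:00"
    val pos = new ParsePosition(0)
    val result = format.parse(field, pos)

    // The lenient parse returns 2015 and stops at the first '-'.
    println(s"result = $result, consumed ${pos.getIndex} of ${field.length} chars")

    // The patch only treats the field as a decimal when the whole input was
    // consumed, so inference falls through to the timestamp parser instead.
    val isDecimal = pos.getIndex == field.length && pos.getErrorIndex == -1
    println(s"isDecimal = $isDecimal") // false
  }
}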
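
From the user's side the feature reduces to one new CSV option. A usage sketch, assuming a Spark build with this series applied; the input path, column name, and file contents are placeholders:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DecimalType, StructType}

object LocaleCsvReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("locale-csv").getOrCreate()

    // Suppose /tmp/prices-de.csv holds German-formatted decimals, e.g. the line: 1000,001
    val schema = new StructType().add("price", DecimalType(10, 5))
    val prices = spark.read
      .schema(schema)
      .option("locale", "de-DE") // the option added by this series; defaults to en-US
      .option("sep", "|")        // keep ',' free to serve as the decimal separator
      .csv("/tmp/prices-de.csv")

    prices.show()
    spark.stop()
  }
}

Choosing a field separator other than ',' matters here for the same reason the tests pass `"sep" -> "|"`: in many locales the comma is the decimal separator.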