diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index ff6c93ae9815..4ed672899419 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.json
 
 import java.io.ByteArrayOutputStream
-import java.util.Locale
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
 
@@ -126,16 +125,11 @@ class JacksonParser(
 
         case VALUE_STRING =>
           // Special case handling for NaN and Infinity.
-          val value = parser.getText
-          val lowerCaseValue = value.toLowerCase(Locale.ROOT)
-          if (lowerCaseValue.equals("nan") ||
-            lowerCaseValue.equals("infinity") ||
-            lowerCaseValue.equals("-infinity") ||
-            lowerCaseValue.equals("inf") ||
-            lowerCaseValue.equals("-inf")) {
-            value.toFloat
-          } else {
-            throw new RuntimeException(s"Cannot parse $value as FloatType.")
+          parser.getText match {
+            case "NaN" => Float.NaN
+            case "Infinity" => Float.PositiveInfinity
+            case "-Infinity" => Float.NegativeInfinity
+            case other => throw new RuntimeException(s"Cannot parse $other as FloatType.")
           }
       }
 
@@ -146,16 +140,11 @@ class JacksonParser(
 
         case VALUE_STRING =>
           // Special case handling for NaN and Infinity.
-          val value = parser.getText
-          val lowerCaseValue = value.toLowerCase(Locale.ROOT)
-          if (lowerCaseValue.equals("nan") ||
-            lowerCaseValue.equals("infinity") ||
-            lowerCaseValue.equals("-infinity") ||
-            lowerCaseValue.equals("inf") ||
-            lowerCaseValue.equals("-inf")) {
-            value.toDouble
-          } else {
-            throw new RuntimeException(s"Cannot parse $value as DoubleType.")
+          parser.getText match {
+            case "NaN" => Double.NaN
+            case "Infinity" => Double.PositiveInfinity
+            case "-Infinity" => Double.NegativeInfinity
+            case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.")
           }
       }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 5e7f7944bd84..e66a60d7503f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.json
 import java.io.{File, StringWriter}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
+import java.util.Locale
 
 import com.fasterxml.jackson.core.JsonFactory
 import org.apache.hadoop.fs.{Path, PathFilter}
@@ -1988,4 +1989,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
     }
   }
+
+  test("SPARK-18772: Parse special floats correctly") {
+    val jsons = Seq(
+      """{"a": "NaN"}""",
+      """{"a": "Infinity"}""",
+      """{"a": "-Infinity"}""")
+
+    // positive cases
+    val checks: Seq[Double => Boolean] = Seq(
+      _.isNaN,
+      _.isPosInfinity,
+      _.isNegInfinity)
+
+    Seq(FloatType, DoubleType).foreach { dt =>
+      jsons.zip(checks).foreach { case (json, check) =>
+        val ds = spark.read
+          .schema(StructType(Seq(StructField("a", dt))))
+          .json(Seq(json).toDS())
+          .select($"a".cast(DoubleType)).as[Double]
+        assert(check(ds.first()))
+      }
+    }
+
+    // negative cases
+    Seq(FloatType, DoubleType).foreach { dt =>
+      val lowerCasedJsons = jsons.map(_.toLowerCase(Locale.ROOT))
+      // The special floats are case-sensitive so these cases below throw exceptions.
+      lowerCasedJsons.foreach { lowerCasedJson =>
+        val e = intercept[SparkException] {
+          spark.read
+            .option("mode", "FAILFAST")
+            .schema(StructType(Seq(StructField("a", dt))))
+            .json(Seq(lowerCasedJson).toDS())
+            .collect()
+        }
+        assert(e.getMessage.contains("Cannot parse"))
+      }
+    }
+  }
 }
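
Note (not part of the patch): the change above makes the special-float tokens case-sensitive and drops the old "inf"/"-inf" aliases, so only the exact strings "NaN", "Infinity", and "-Infinity" parse for FloatType/DoubleType. Below is a minimal, self-contained sketch of that behavior, assuming a local SparkSession on a build that includes this patch; the object name SpecialFloatsCheck and the use of scala.util.Try in place of ScalaTest's intercept are illustrative only.

import scala.util.Try

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object SpecialFloatsCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("special-floats").getOrCreate()
    import spark.implicits._

    val schema = StructType(Seq(StructField("a", DoubleType)))

    // The exact, case-sensitive token maps to the IEEE 754 special value.
    val ok = spark.read.schema(schema).json(Seq("""{"a": "NaN"}""").toDS())
    assert(ok.first().getDouble(0).isNaN)

    // A lower-cased variant such as "nan" is rejected; under FAILFAST mode
    // the failure surfaces when the job actually runs (here, at collect()).
    val failed = Try {
      spark.read
        .option("mode", "FAILFAST")
        .schema(schema)
        .json(Seq("""{"a": "nan"}""").toDS())
        .collect()
    }
    assert(failed.isFailure)

    spark.stop()
  }
}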