ExprUtils.scala
@@ -17,6 +17,9 @@

package org.apache.spark.sql.catalyst.expressions

import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
@@ -83,4 +86,22 @@ object ExprUtils {
}
}
}

def getDecimalParser(locale: Locale): String => java.math.BigDecimal = {
if (locale == Locale.US) { // Special handling for the default locale, kept for backward compatibility
(s: String) => new java.math.BigDecimal(s.replaceAll(",", ""))
} else {
val decimalFormat = new DecimalFormat("", new DecimalFormatSymbols(locale))
decimalFormat.setParseBigDecimal(true)
(s: String) => {
val pos = new ParsePosition(0)
val result = decimalFormat.parse(s, pos).asInstanceOf[java.math.BigDecimal]
if (pos.getIndex() != s.length() || pos.getErrorIndex() != -1) {
throw new IllegalArgumentException("Cannot parse any decimal")
} else {
result
}
}
}
}
}
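
A minimal sketch of how the new helper behaves, assuming the code above is on the classpath; the sample strings are illustrative and not taken from the change:

import java.util.Locale
import org.apache.spark.sql.catalyst.expressions.ExprUtils

// The US locale keeps the old behavior: commas are stripped and java.math.BigDecimal parses the rest.
val usParser = ExprUtils.getDecimalParser(Locale.US)
usParser("1,000.001")        // -> 1000.001

// Any other locale goes through DecimalFormat with that locale's symbols,
// e.g. a comma as the decimal separator for German.
val deParser = ExprUtils.getDecimalParser(Locale.GERMANY)
deParser("1000,001")         // -> 1000.001
deParser("not a number")     // -> IllegalArgumentException("Cannot parse any decimal")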
jsonExpressions.scala
@@ -23,12 +23,10 @@ import scala.util.parsing.combinator.RegexParsers

import com.fasterxml.jackson.core._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -775,6 +773,9 @@ case class SchemaOfJson(
factory
}

@transient
private lazy val jsonInferSchema = new JsonInferSchema(jsonOptions)

@transient
private lazy val json = child.eval().asInstanceOf[UTF8String]

@@ -787,7 +788,7 @@
override def eval(v: InternalRow): Any = {
val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
parser.nextToken()
inferField(parser, jsonOptions)
jsonInferSchema.inferField(parser)
}

UTF8String.fromString(dt.catalogString)
JacksonParser.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -135,6 +136,8 @@ class JacksonParser(
}
}

private val decimalParser = ExprUtils.getDecimalParser(options.locale)

/**
* Create a converter which converts the JSON documents held by the `JsonParser`
* to a value according to a desired schema.
@@ -261,6 +264,9 @@
(parser: JsonParser) => parseJsonToken[Decimal](parser, dataType) {
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) =>
Decimal(parser.getDecimalValue, dt.precision, dt.scale)
case VALUE_STRING if parser.getTextLength >= 1 =>
val bigDecimal = decimalParser(parser.getText)
Decimal(bigDecimal, dt.precision, dt.scale)
}

case st: StructType =>
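
The net effect of the JacksonParser change, sketched at the DataFrame level; this is hypothetical usage that assumes a SparkSession named spark and the locale option being threaded through JSONOptions as above:

import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.{DecimalType, StructType}
import spark.implicits._

val schema = new StructType().add("d", DecimalType(10, 5))
val df = Seq("""{"d": "1000,001"}""").toDF("json")

// With locale=de-DE the quoted value is routed through the locale-aware decimal parser
// instead of failing the string-to-decimal conversion.
df.select(from_json($"json", schema, Map("locale" -> "de-DE")).as("parsed")).show(false)
// Expected (illustrative): [1000.00100]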
JsonInferSchema.scala
@@ -19,18 +19,23 @@ package org.apache.spark.sql.catalyst.json

import java.util.Comparator

import scala.util.control.Exception.allCatch

import com.fasterxml.jackson.core._

import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

private[sql] object JsonInferSchema {
private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {

private val decimalParser = ExprUtils.getDecimalParser(options.locale)

/**
* Infer the type of a collection of json records in three stages:
@@ -40,21 +45,20 @@ private[sql] object JsonInferSchema {
*/
def infer[T](
json: RDD[T],
configOptions: JSONOptions,
createParser: (JsonFactory, T) => JsonParser): StructType = {
val parseMode = configOptions.parseMode
val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
val parseMode = options.parseMode
val columnNameOfCorruptRecord = options.columnNameOfCorruptRecord

// In each RDD partition, perform schema inference on each row and merge afterwards.
val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode)
val typeMerger = JsonInferSchema.compatibleRootType(columnNameOfCorruptRecord, parseMode)
val mergedTypesFromPartitions = json.mapPartitions { iter =>
val factory = new JsonFactory()
configOptions.setJacksonOptions(factory)
options.setJacksonOptions(factory)
iter.flatMap { row =>
try {
Utils.tryWithResource(createParser(factory, row)) { parser =>
parser.nextToken()
Some(inferField(parser, configOptions))
Some(inferField(parser))
}
} catch {
case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match {
@@ -82,42 +86,25 @@
}
json.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult)

canonicalizeType(rootType, configOptions) match {
canonicalizeType(rootType, options) match {
case Some(st: StructType) => st
case _ =>
// canonicalizeType erases all empty structs, including the only one we want to keep
StructType(Nil)
}
}

private[this] val structFieldComparator = new Comparator[StructField] {
override def compare(o1: StructField, o2: StructField): Int = {
o1.name.compareTo(o2.name)
}
}

private def isSorted(arr: Array[StructField]): Boolean = {
var i: Int = 0
while (i < arr.length - 1) {
if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
return false
}
i += 1
}
true
}

/**
* Infer the type of a json document from the parser's token stream
*/
def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
def inferField(parser: JsonParser): DataType = {
import com.fasterxml.jackson.core.JsonToken._
parser.getCurrentToken match {
case null | VALUE_NULL => NullType

case FIELD_NAME =>
parser.nextToken()
inferField(parser, configOptions)
inferField(parser)

case VALUE_STRING if parser.getTextLength < 1 =>
// Zero length strings and nulls have special handling to deal
@@ -128,18 +115,25 @@
// record fields' types have been combined.
NullType

case VALUE_STRING if options.prefersDecimal =>
val decimalTry = allCatch opt {
val bigDecimal = decimalParser(parser.getText)
DecimalType(bigDecimal.precision, bigDecimal.scale)
}
decimalTry.getOrElse(StringType)
case VALUE_STRING => StringType

case START_OBJECT =>
val builder = Array.newBuilder[StructField]
while (nextUntil(parser, END_OBJECT)) {
builder += StructField(
parser.getCurrentName,
inferField(parser, configOptions),
inferField(parser),
nullable = true)
}
val fields: Array[StructField] = builder.result()
// Note: other code relies on this sorting for correctness, so don't remove it!
java.util.Arrays.sort(fields, structFieldComparator)
java.util.Arrays.sort(fields, JsonInferSchema.structFieldComparator)
StructType(fields)

case START_ARRAY =>
@@ -148,15 +142,15 @@
// the type as we pass through all JSON objects.
var elementType: DataType = NullType
while (nextUntil(parser, END_ARRAY)) {
elementType = compatibleType(
elementType, inferField(parser, configOptions))
elementType = JsonInferSchema.compatibleType(
elementType, inferField(parser))
}

ArrayType(elementType)

case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if options.primitivesAsString => StringType

case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType
case (VALUE_TRUE | VALUE_FALSE) if options.primitivesAsString => StringType

case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
import JsonParser.NumberType._
@@ -172,7 +166,7 @@
} else {
DoubleType
}
case FLOAT | DOUBLE if configOptions.prefersDecimal =>
case FLOAT | DOUBLE if options.prefersDecimal =>
val v = parser.getDecimalValue
if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
DecimalType(Math.max(v.precision(), v.scale()), v.scale())
@@ -217,20 +211,39 @@

case other => Some(other)
}
}

object JsonInferSchema {
val structFieldComparator = new Comparator[StructField] {
override def compare(o1: StructField, o2: StructField): Int = {
o1.name.compareTo(o2.name)
}
}

def isSorted(arr: Array[StructField]): Boolean = {
var i: Int = 0
while (i < arr.length - 1) {
if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
return false
}
i += 1
}
true
}

private def withCorruptField(
def withCorruptField(
struct: StructType,
other: DataType,
columnNameOfCorruptRecords: String,
parseMode: ParseMode) = parseMode match {
parseMode: ParseMode): StructType = parseMode match {
case PermissiveMode =>
// If we see any other data type at the root level, we get records that cannot be
// parsed. So, we use the struct as the data type and add the corrupt field to the schema.
if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
// If this given struct does not have a column used for corrupt records,
// add this field.
val newFields: Array[StructField] =
StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
// Note: other code relies on this sorting for correctness, so don't remove it!
java.util.Arrays.sort(newFields, structFieldComparator)
StructType(newFields)
@@ -253,7 +266,7 @@
/**
* Remove top-level ArrayType wrappers and merge the remaining schemas
*/
private def compatibleRootType(
def compatibleRootType(
columnNameOfCorruptRecords: String,
parseMode: ParseMode): (DataType, DataType) => DataType = {
// Since we support array of json objects at the top level,
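
A hypothetical end-to-end sketch of the inference path this class now owns (assumes a SparkSession named spark; the option names match the tests below):

import spark.implicits._

val ds = Seq("""{"price": "1000,001"}""").toDS()

// With prefersDecimal=true plus a locale, a quoted number can be inferred as a decimal
// rather than a plain string.
val df = spark.read
  .option("locale", "de-DE")
  .option("prefersDecimal", "true")
  .json(ds)

df.printSchema()
// Expected (illustrative): price: decimal(7,3)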
JsonExpressionsSuite.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.text.SimpleDateFormat
import java.text.{DecimalFormat, DecimalFormatSymbols, SimpleDateFormat}
import java.util.{Calendar, Locale}

import org.scalatest.exceptions.TestFailedException
@@ -765,4 +765,44 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
timeZoneId = gmtId),
expectedErrMsg = "The field for corrupt records must be string type and nullable")
}

def decimalInput(langTag: String): (Decimal, String) = {
val decimalVal = new java.math.BigDecimal("1000.001")
val decimalType = new DecimalType(10, 5)
val expected = Decimal(decimalVal, decimalType.precision, decimalType.scale)
val decimalFormat = new DecimalFormat("",
new DecimalFormatSymbols(Locale.forLanguageTag(langTag)))
val input = s"""{"d": "${decimalFormat.format(expected.toBigDecimal)}"}"""

(expected, input)
}

test("parse decimals using locale") {
def checkDecimalParsing(langTag: String): Unit = {
val schema = new StructType().add("d", DecimalType(10, 5))
val options = Map("locale" -> langTag)
val (expected, input) = decimalInput(langTag)

checkEvaluation(
JsonToStructs(schema, options, Literal.create(input), gmtId),
InternalRow(expected))
}

Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing)
}

test("inferring the decimal type using locale") {
def checkDecimalInfer(langTag: String, expectedType: String): Unit = {
val options = Map("locale" -> langTag, "prefersDecimal" -> "true")
val (_, input) = decimalInput(langTag)

checkEvaluation(
SchemaOfJson(Literal.create(input), options),
expectedType)
}

Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach {
checkDecimalInfer(_, """struct<d:decimal(7,3)>""")
}
}
}
JsonDataSource.scala
@@ -107,7 +107,7 @@ object TextInputJsonDataSource extends JsonDataSource {
}.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))

SQLExecution.withSQLConfPropagated(json.sparkSession) {
JsonInferSchema.infer(rdd, parsedOptions, rowParser)
new JsonInferSchema(parsedOptions).infer(rdd, rowParser)
}
}

@@ -166,7 +166,7 @@ object MultiLineJsonDataSource extends JsonDataSource {
.getOrElse(createParser(_: JsonFactory, _: PortableDataStream))

SQLExecution.withSQLConfPropagated(sparkSession) {
JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
new JsonInferSchema(parsedOptions).infer[PortableDataStream](sampled, parser)
}
}
