diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 448a4732001b5..cd3c0df006f6e 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None,
- encoding=None):
+ dropFieldIfAllNull=None, encoding=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.
@@ -246,6 +246,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
:param samplingRatio: defines fraction of input JSON objects used for schema inferring.
If None is set, it uses the default value, ``1.0``.
+        :param dropFieldIfAllNull: whether to ignore columns of all null values or empty
+                                   arrays/structs during schema inference. If None is set, it
+                                   uses the default value, ``false``.
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 5f130af606e19..f2a48ccf4526a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -73,6 +73,9 @@ private[sql] class JSONOptions(
val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
+  // Whether to ignore columns of all null values or empty arrays/structs during schema inference
+ val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false)
+
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
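
For what it's worth, the raw option value goes through Scala's `String.toBoolean`, so only "true"/"false" (in any letter case) are accepted and anything else fails when the options are parsed. A small standalone sketch of that behaviour, not Spark-internal code; `params` stands in for the case-insensitive option map the reader builds:

```scala
// Illustration only: mirrors how JSONOptions interprets the option string.
def parseDropFieldIfAllNull(params: Map[String, String]): Boolean =
  params.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false)

parseDropFieldIfAllNull(Map.empty)                             // false (the default)
parseDropFieldIfAllNull(Map("dropFieldIfAllNull" -> "TRUE"))   // true
// parseDropFieldIfAllNull(Map("dropFieldIfAllNull" -> "yes")) // IllegalArgumentException
```
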
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 53f44888ebaff..ff066629649b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -379,6 +379,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* that should be used for parsing.
*
   * `samplingRatio` (default is 1.0): defines fraction of input JSON objects used
   * for schema inferring.
+   * `dropFieldIfAllNull` (default `false`): whether to ignore columns of all null values or
+   * empty arrays/structs during schema inference.
*
*
* @since 2.0.0
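
For context, a minimal end-to-end usage sketch, assuming a SparkSession named `spark` and a directory of JSON files in which column "a" is null in every record; the path and column name are illustrative, and the Python reader takes the same string option:

```scala
val df = spark.read
  .option("dropFieldIfAllNull", "true")
  .json("/path/to/json-dir")

// With the option enabled, the all-null column "a" is dropped during schema inference;
// with the default (false) it would instead show up as a nullable string column.
df.printSchema()
```
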
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index a270a6451d5dd..97ed1dc35c97c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -70,7 +70,7 @@ private[sql] object JsonInferSchema {
}.fold(StructType(Nil))(
compatibleRootType(columnNameOfCorruptRecord, parseMode))
- canonicalizeType(rootType) match {
+ canonicalizeType(rootType, configOptions) match {
case Some(st: StructType) => st
case _ =>
// canonicalizeType erases all empty structs, including the only one we want to keep
@@ -176,33 +176,33 @@ private[sql] object JsonInferSchema {
}
/**
- * Convert NullType to StringType and remove StructTypes with no fields
+   * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields and,
+   * depending on the provided options, either drops NullTypes or converts them to StringType.
*/
- private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match {
- case at @ ArrayType(elementType, _) =>
- for {
- canonicalType <- canonicalizeType(elementType)
- } yield {
- at.copy(canonicalType)
- }
+ private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match {
+ case at: ArrayType =>
+ canonicalizeType(at.elementType, options)
+ .map(t => at.copy(elementType = t))
case StructType(fields) =>
- val canonicalFields: Array[StructField] = for {
- field <- fields
- if field.name.length > 0
- canonicalType <- canonicalizeType(field.dataType)
- } yield {
- field.copy(dataType = canonicalType)
+ val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f =>
+ canonicalizeType(f.dataType, options)
+ .map(t => f.copy(dataType = t))
}
-
- if (canonicalFields.length > 0) {
- Some(StructType(canonicalFields))
+ // SPARK-8093: empty structs should be deleted
+ if (canonicalFields.isEmpty) {
+ None
} else {
- // per SPARK-8093: empty structs should be deleted
+ Some(StructType(canonicalFields))
+ }
+
+ case NullType =>
+ if (options.dropFieldIfAllNull) {
None
+ } else {
+ Some(StringType)
}
- case NullType => Some(StringType)
case other => Some(other)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index ae93965bc50ed..ef8dc3a325a33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -270,6 +270,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging {
* per file
* `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
* that should be used for parsing.
+    * `dropFieldIfAllNull` (default `false`): whether to ignore columns of all null values or
+    * empty arrays/structs during schema inference.
*
*
* @since 2.0.0
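
On the streaming side the option only comes into play once Spark is actually allowed to infer a schema for the file stream; as far as I understand, file-based streaming sources otherwise require a user-supplied schema, which would make the option moot. A hedged sketch, where the path and the `schemaInference` setting are assumptions about how one would exercise it rather than part of the patch:

```scala
// Allow schema inference for file streams so dropFieldIfAllNull can take effect.
spark.conf.set("spark.sql.streaming.schemaInference", "true")

val stream = spark.readStream
  .option("dropFieldIfAllNull", "true")
  .json("/path/to/json-dir")   // illustrative path

stream.printSchema()
```
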
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 0db688fec9a67..0e4523bfe088c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2408,4 +2408,53 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()),
Row(badJson))
}
+
+ test("SPARK-23772 ignore column of all null values or empty array during schema inference") {
+ withTempPath { tempDir =>
+ val path = tempDir.getAbsolutePath
+
+ // primitive types
+ Seq(
+ """{"a":null, "b":1, "c":3.0}""",
+ """{"a":null, "b":null, "c":"string"}""",
+ """{"a":null, "b":null, "c":null}""")
+ .toDS().write.text(path)
+ var df = spark.read.format("json")
+ .option("dropFieldIfAllNull", true)
+ .load(path)
+ var expectedSchema = new StructType()
+ .add("b", LongType).add("c", StringType)
+ assert(df.schema === expectedSchema)
+ checkAnswer(df, Row(1, "3.0") :: Row(null, "string") :: Row(null, null) :: Nil)
+
+ // arrays
+ Seq(
+ """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]], "e":[[], null, [[]]]}""",
+ """{"a":[null], "b":[null], "c":[], "d":[null, []], "e":null}""",
+ """{"a":null, "b":null, "c":[], "d":null, "e":[null, [], null]}""")
+ .toDS().write.mode("overwrite").text(path)
+ df = spark.read.format("json")
+ .option("dropFieldIfAllNull", true)
+ .load(path)
+ expectedSchema = new StructType()
+ .add("a", ArrayType(LongType))
+ assert(df.schema === expectedSchema)
+ checkAnswer(df, Row(Array(2, 1)) :: Row(Array(null)) :: Row(null) :: Nil)
+
+ // structs
+ Seq(
+ """{"a":{"a1": 1, "a2":"string"}, "b":{}}""",
+ """{"a":{"a1": 2, "a2":null}, "b":{"b1":[null]}}""",
+ """{"a":null, "b":null}""")
+ .toDS().write.mode("overwrite").text(path)
+ df = spark.read.format("json")
+ .option("dropFieldIfAllNull", true)
+ .load(path)
+ expectedSchema = new StructType()
+ .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType)
+ :: Nil))
+ assert(df.schema === expectedSchema)
+ checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil)
+ }
+ }
}