diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
index 286120ff40b8..618127fb6e61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.xml
import java.io.File
+import java.nio.file.Files
+import java.util.UUID
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row}
@@ -35,7 +37,11 @@ import org.apache.spark.sql.types.{
StructType
}
-class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXmlData {
+class XmlInferSchemaSuite
+ extends QueryTest
+ with SharedSparkSession
+ with TestXmlData
+ with XmlSchemaInferenceCaseSensitivityTests {
private val baseOptions = Map("rowTag" -> "ROW")
@@ -50,6 +56,9 @@ class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXml
spark.read.options(baseOptions ++ options).xml(dataset)
}
+  // This suite adds no extra SQL configuration around the mixed-in case-sensitivity tests.
+  override protected def customSQLConf: Map[String, String] = Map.empty[String, String]
+
// TODO: add tests for type widening
test("Type conflict in primitive field values") {
val xmlDF = readData(primitiveFieldValueTypeConflict, Map("nullValue" -> ""))
@@ -630,3 +639,306 @@ class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXml
checkAnswer(xmlDF, expectedAns)
}
}
+
+trait XmlSchemaInferenceCaseSensitivityTests extends QueryTest {
+
+  /**
+   * SQL configuration entries that are applied, through `withSQLConf`, around every
+   * case-sensitivity test registered by this trait. Empty by default; a suite mixing the
+   * trait in may override it.
+   */
+  protected def customSQLConf: Map[String, String] = Map.empty[String, String]
+
+  /**
+   * Writes `xmlString` to a file named `fileName` inside `dir` and returns the file's path.
+   *
+   * @param xmlString XML text to write
+   * @param dir       directory that receives the file
+   * @param multiline when false, every character below U+0020 (line breaks, tabs and other
+   *                  control characters) is dropped before writing
+   * @param fileName  name of the file to create; a random UUID by default
+   * @return the canonical path of `dir` joined to `fileName` with a forward slash
+   */
+  private def writeXmlStringToFile(
+      xmlString: String,
+      dir: File,
+      multiline: Boolean = true,
+      fileName: String = UUID.randomUUID().toString): String = {
+    import java.nio.charset.StandardCharsets
+
+    val content = if (multiline) xmlString else xmlString.filter(_ >= ' ')
+    // Encode explicitly as UTF-8 instead of relying on the JVM default charset, so the bytes
+    // written to disk do not depend on the platform the suite runs on.
+    Files.write(new File(dir, fileName).toPath, content.getBytes(StandardCharsets.UTF_8))
+    s"${dir.getCanonicalPath}/$fileName"
+  }
+
+  // Value-tag case: row 1 has an element `a` holding the text 1 plus a child `b`; row 2 has a
+  // plain element `A`. The expectations below keep `A` and `a` as separate fields under
+  // case-sensitive analysis, and fold `A` into `a` (its text landing in `_VALUE`) under
+  // case-insensitive analysis.
+  private val valueTagCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
+    val caseSensitiveValueTag =
+      """
+        |<ROW>
+        |    <a>
+        |        1
+        |        <b>2</b>
+        |    </a>
+        |</ROW>
+        |<ROW>
+        |    <A>
+        |        3
+        |    </A>
+        |</ROW>
+        |""".stripMargin
+    XmlSchemaInferenceCaseSensitiveTestCase(
+      "value tag",
+      caseSensitiveValueTag,
+      expectedCaseSensitiveSchema = new StructType()
+        .add("A", LongType)
+        .add("a", new StructType().add("_VALUE", LongType).add("b", LongType)),
+      expectedCaseSensitiveAns = Seq(
+        Row(null, Row(1, 2)),
+        Row(3, null)
+      ),
+      expectedCaseInsensitiveSchema = new StructType()
+        .add("a", new StructType().add("_VALUE", LongType).add("b", LongType)),
+      expectedCaseInsensitiveAns = Seq(
+        Row(Row(1, 2)),
+        Row(Row(3, null))
+      )
+    )
+  }
+
+  // Array case mixing a value-tagged struct with primitives: row 1 has an element `a` (text 1
+  // with children `b` and `c`) next to a plain element `A`; row 2 has only `A`. The
+  // expectations below keep a long `A` and a struct `a` under case-sensitive analysis, and
+  // merge them into a single array-of-struct field `a` under case-insensitive analysis.
+  // NOTE(review): this val is named "...Complex..." but registers as "array type - simple";
+  // the other array test case has the opposite pairing -- confirm which labels are intended.
+  private val arrayComplexCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
+    val caseSensitiveArrayType =
+      """
+        |<ROW>
+        |    <a>
+        |        1
+        |        <b>2</b>
+        |        <c>3</c>
+        |    </a>
+        |    <A>
+        |        4
+        |    </A>
+        |</ROW>
+        |<ROW>
+        |    <A>5</A>
+        |</ROW>
+        |""".stripMargin
+    XmlSchemaInferenceCaseSensitiveTestCase(
+      "array type - simple",
+      caseSensitiveArrayType,
+      expectedCaseSensitiveSchema = new StructType()
+        .add("A", LongType)
+        .add(
+          "a",
+          new StructType()
+            .add("_VALUE", LongType)
+            .add("b", LongType)
+            .add("c", LongType)
+        ),
+      expectedCaseSensitiveAns = Seq(
+        Row(4, Row(1, 2, 3)),
+        Row(5, null)
+      ),
+      expectedCaseInsensitiveSchema = new StructType()
+        .add(
+          "a",
+          ArrayType(
+            new StructType()
+              .add("_VALUE", LongType)
+              .add("b", LongType)
+              .add("c", LongType)
+          )
+        ),
+      expectedCaseInsensitiveAns = Seq(
+        Row(List(Row(1, 2, 3), Row(4, null, null))),
+        Row(List(Row(5, null, null)))
+      )
+    )
+  }
+
+  // Array case built from two structs in a single row: element `a` with children `b` and `c`,
+  // followed by element `A` with children `B` and `c`. The expectations below keep two
+  // distinct struct fields under case-sensitive analysis, and merge them into one
+  // array-of-struct field `a` (with fields `b` and `c`) under case-insensitive analysis.
+  // NOTE(review): this val is named "...Simple..." but registers as "array type - complex";
+  // the other array test case has the opposite pairing -- confirm which labels are intended.
+  private val arraySimpleCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
+    val caseSensitiveArrayType =
+      """
+        |<ROW>
+        |    <a>
+        |        <b>1</b>
+        |        <c>2</c>
+        |    </a>
+        |    <A>
+        |        <B>3</B>
+        |        <c>4</c>
+        |    </A>
+        |</ROW>
+        |""".stripMargin
+    XmlSchemaInferenceCaseSensitiveTestCase(
+      "array type - complex",
+      caseSensitiveArrayType,
+      expectedCaseSensitiveSchema = new StructType()
+        .add("A", new StructType().add("B", LongType).add("c", LongType))
+        .add("a", new StructType().add("b", LongType).add("c", LongType)),
+      expectedCaseSensitiveAns = Seq(
+        Row(Row(3, 4), Row(1, 2))
+      ),
+      expectedCaseInsensitiveSchema = new StructType()
+        .add("a", ArrayType(new StructType().add("b", LongType).add("c", LongType))),
+      expectedCaseInsensitiveAns = Seq(
+        Row(List(Row(1, 2), Row(3, 4)))
+      )
+    )
+  }
+
+  // Primitive case: row 1 has elements `a`, `b`, `c`; row 2 has `a`, `B`, `c`. The
+  // expectations below treat `B` and `b` as separate long fields under case-sensitive
+  // analysis, and as the single field `b` under case-insensitive analysis.
+  private val primitiveTypeCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
+    val caseInSensitivePrimitiveTypes =
+      """
+        |<ROW>
+        |    <a>1</a>
+        |    <b>2</b>
+        |    <c>3</c>
+        |</ROW>
+        |<ROW>
+        |    <a>4</a>
+        |    <B>5</B>
+        |    <c>6</c>
+        |</ROW>
+        |""".stripMargin
+    XmlSchemaInferenceCaseSensitiveTestCase(
+      "primitive type",
+      caseInSensitivePrimitiveTypes,
+      expectedCaseSensitiveSchema = new StructType()
+        .add("B", LongType)
+        .add("a", LongType)
+        .add("b", LongType)
+        .add("c", LongType),
+      expectedCaseSensitiveAns = Seq(
+        Row(null, 1, 2, 3),
+        Row(5, 4, null, 6)
+      ),
+      expectedCaseInsensitiveSchema = new StructType()
+        .add("a", LongType)
+        .add("b", LongType)
+        .add("c", LongType),
+      expectedCaseInsensitiveAns = Seq(
+        Row(1, 2, 3),
+        Row(4, 5, 6)
+      )
+    )
+  }
+
+  // Attribute case: both rows carry elements `a`, `b`, `c`; row 1 sets the attribute `attr`
+  // on the row tag and row 2 sets `aTtr`. The expectations below keep `_aTtr` and `_attr` as
+  // separate fields under case-sensitive analysis, and as the single field `_attr` under
+  // case-insensitive analysis.
+  private val attributesCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
+    val caseSensitiveAttr =
+      """
+        |<ROW attr="1">
+        |    <a>1</a>
+        |    <b>2</b>
+        |    <c>3</c>
+        |</ROW>
+        |<ROW aTtr="2">
+        |    <a>4</a>
+        |    <b>5</b>
+        |    <c>6</c>
+        |</ROW>
+        |""".stripMargin
+    XmlSchemaInferenceCaseSensitiveTestCase(
+      "attributes",
+      caseSensitiveAttr,
+      expectedCaseSensitiveSchema = new StructType()
+        .add("_aTtr", LongType)
+        .add("_attr", LongType)
+        .add("a", LongType)
+        .add("b", LongType)
+        .add("c", LongType),
+      expectedCaseSensitiveAns = Seq(
+        Row(null, 1, 1, 2, 3),
+        Row(2, null, 4, 5, 6)
+      ),
+      expectedCaseInsensitiveSchema = new StructType()
+        .add("_attr", LongType)
+        .add("a", LongType)
+        .add("b", LongType)
+        .add("c", LongType),
+      expectedCaseInsensitiveAns = Seq(
+        Row(1, 1, 2, 3),
+        Row(2, 4, 5, 6)
+      )
+    )
+  }
+
+  // Struct case: row 1 nests children `a` and `c` under `A`; row 2 nests children `A` and `c`
+  // under `a`. The expectations below keep two distinct structs under case-sensitive
+  // analysis, and a single struct `A` with fields `a` and `c` under case-insensitive analysis.
+  private val structCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
+    val caseSensitiveStruct =
+      """
+        |<ROW>
+        |    <A>
+        |        <a>1</a>
+        |        <c>3</c>
+        |    </A>
+        |</ROW>
+        |<ROW>
+        |    <a>
+        |        <A>5</A>
+        |        <c>7</c>
+        |    </a>
+        |</ROW>
+        |""".stripMargin
+    XmlSchemaInferenceCaseSensitiveTestCase(
+      "struct",
+      caseSensitiveStruct,
+      expectedCaseSensitiveSchema = new StructType()
+        .add(
+          "A",
+          new StructType()
+            .add("a", LongType)
+            .add("c", LongType)
+        )
+        .add(
+          "a",
+          new StructType()
+            .add("A", LongType)
+            .add("c", LongType)
+        ),
+      expectedCaseSensitiveAns = Seq(
+        Row(Row(1, 3), null),
+        Row(null, Row(5, 7))
+      ),
+      expectedCaseInsensitiveSchema = new StructType()
+        .add(
+          "A",
+          new StructType()
+            .add("a", LongType)
+            .add("c", LongType)
+        ),
+      expectedCaseInsensitiveAns = Seq(
+        Row(Row(1, 3)),
+        Row(Row(5, 7))
+      )
+    )
+  }
+
+  /**
+   * One schema-inference scenario together with what it should produce under both settings
+   * of `SQLConf.CASE_SENSITIVE`.
+   *
+   * @param name                          label appended to the registered test name
+   * @param xmlString                     XML document written to disk and read back
+   * @param expectedCaseSensitiveSchema   schema expected when case sensitivity is "true"
+   * @param expectedCaseSensitiveAns      rows expected when case sensitivity is "true"
+   * @param expectedCaseInsensitiveSchema schema expected when case sensitivity is "false"
+   * @param expectedCaseInsensitiveAns    rows expected when case sensitivity is "false"
+   */
+  case class XmlSchemaInferenceCaseSensitiveTestCase(
+      name: String,
+      xmlString: String,
+      expectedCaseSensitiveSchema: StructType,
+      expectedCaseSensitiveAns: Seq[Row],
+      expectedCaseInsensitiveSchema: StructType,
+      expectedCaseInsensitiveAns: Seq[Row])
+
+  // Every case-sensitivity scenario, listed in the order in which its test is registered.
+  private val testcases: Seq[XmlSchemaInferenceCaseSensitiveTestCase] =
+    Seq(
+      valueTagCaseSensitivityTestcase,
+      arrayComplexCaseSensitivityTestcase,
+      arraySimpleCaseSensitivityTestcase,
+      primitiveTypeCaseSensitivityTestcase,
+      structCaseSensitivityTestcase,
+      attributesCaseSensitivityTestcase
+    )
+
+  // Registers one test per scenario. Each test writes the scenario's XML into a fresh
+  // temporary directory and reads that directory back twice -- first with case sensitivity
+  // set to "false", then to "true" -- checking the inferred schema and the rows each time.
+  for (scenario <- testcases) {
+    test(s"case sensitivity test - ${scenario.name}") {
+      withTempDir { tmpDir =>
+        withSQLConf(customSQLConf.toSeq: _*) {
+          writeXmlStringToFile(scenario.xmlString, tmpDir)
+          // (value for SQLConf.CASE_SENSITIVE, expected schema, expected rows)
+          val expectationsByMode = Seq(
+            ("false", scenario.expectedCaseInsensitiveSchema, scenario.expectedCaseInsensitiveAns),
+            ("true", scenario.expectedCaseSensitiveSchema, scenario.expectedCaseSensitiveAns)
+          )
+          expectationsByMode.foreach { case (caseSensitive, expectedSchema, expectedRows) =>
+            withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+              val xml = spark.read.options(Map("rowTag" -> "ROW")).xml(tmpDir.getCanonicalPath)
+              assert(xml.schema == expectedSchema)
+              checkAnswer(xml, expectedRows)
+            }
+          }
+        }
+      }
+    }
+  }
+}
+