diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala index 286120ff40b8..618127fb6e61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.xml import java.io.File +import java.nio.file.Files +import java.util.UUID import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row} @@ -35,7 +37,11 @@ import org.apache.spark.sql.types.{ StructType } -class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXmlData { +class XmlInferSchemaSuite + extends QueryTest + with SharedSparkSession + with TestXmlData + with XmlSchemaInferenceCaseSensitivityTests { private val baseOptions = Map("rowTag" -> "ROW") @@ -50,6 +56,9 @@ class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXml spark.read.options(baseOptions ++ options).xml(dataset) } + override protected def customSQLConf + : Map[String, String] = Map.empty + // TODO: add tests for type widening test("Type conflict in primitive field values") { val xmlDF = readData(primitiveFieldValueTypeConflict, Map("nullValue" -> "")) @@ -630,3 +639,306 @@ class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXml checkAnswer(xmlDF, expectedAns) } } + +/** Mixin of XML schema-inference tests: every scenario in testcases is read once with SQLConf.CASE_SENSITIVE set to false and once with it set to true, and both the inferred schema and the rows are compared with the expectations stored in the scenario. */ trait XmlSchemaInferenceCaseSensitivityTests extends QueryTest { + + /** Extra SQL conf entries applied around each test body through withSQLConf; empty unless a suite mixing this trait in overrides it. */ protected def customSQLConf: Map[String, String] = Map.empty + + /** Writes xmlString into the file fileName under dir and returns the canonical path of dir joined to fileName with a slash. When multiline is false, every character below the space character (line breaks, tabs and other control characters) is removed before the write. NOTE(review): getBytes is called without a charset, so the bytes depend on the platform default encoding - consider an explicit charset. */ private def writeXmlStringToFile( + xmlString: String, + dir: File, + multiline: Boolean = true, + fileName: String = UUID.randomUUID().toString): String = { + val bytes = if (multiline) xmlString.getBytes() else xmlString.filter(_ >= ' ').getBytes + Files.write(new File(dir, 
fileName).toPath, bytes) + dir.getCanonicalPath + s"/$fileName" + } + + /** Value-tag scenario: with case sensitivity on, A (long) and a (struct of _VALUE and b) are expected as separate columns; with it off, a single struct column a is expected. */ private val valueTagCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = { + val caseSensitiveValueTag = + """ + | + | + | 1 + | 2 + | + | + | + | + | 3 + | + | + |""".stripMargin + XmlSchemaInferenceCaseSensitiveTestCase( + "value tag", + caseSensitiveValueTag, + expectedCaseSensitiveSchema = new StructType() + .add("A", LongType) + .add("a", new StructType().add("_VALUE", LongType).add("b", LongType)), + expectedCaseSensitiveAns = Seq( + Row(null, Row(1, 2)), + Row(3, null) + ), + expectedCaseInsensitiveSchema = new StructType() + .add("a", new StructType().add("_VALUE", LongType).add("b", LongType)), + expectedCaseInsensitiveAns = Seq( + Row(Row(1, 2)), + Row(Row(3, null)) + ) + ) + } + + // array type: a A b and A & a struct and A struct, a + /* NOTE(review): this val is called arrayComplex but registers the test name array type - simple, while arraySimple below registers array type - complex; the two labels look swapped - confirm which side is intended. */ private val arrayComplexCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = { + val caseSensitiveArrayType = + """ + | + | + | 1 + | 2 + | 3 + | + | + | 4 + | + | + | + | 5 + | + |""".stripMargin + XmlSchemaInferenceCaseSensitiveTestCase( + "array type - simple", + caseSensitiveArrayType, + expectedCaseSensitiveSchema = new StructType() + .add("A", LongType) + .add( + "a", + new StructType() + .add("_VALUE", LongType) + .add("b", LongType) + .add("c", LongType) + ), + expectedCaseSensitiveAns = Seq( + Row(4, Row(1, 2, 3)), + Row(5, null) + ), + expectedCaseInsensitiveSchema = new StructType() + .add( + "a", + ArrayType( + new StructType() + .add("_VALUE", LongType) + .add("b", LongType) + .add("c", LongType) + ) + ), + expectedCaseInsensitiveAns = Seq( + Row(List(Row(1, 2, 3), Row(4, null, null))), + Row(List(Row(5, null, null))) + ) + ) + } + + private val arraySimpleCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = { + val caseSensitiveArrayType = + """ + | + | + | 1 + | 2 + | + | + | 3 + | 4 + | + | + |""".stripMargin + XmlSchemaInferenceCaseSensitiveTestCase( + "array type - complex", 
caseSensitiveArrayType, + expectedCaseSensitiveSchema = new StructType() + .add("A", new StructType().add("B", LongType).add("c", LongType)) + .add("a", new StructType().add("b", LongType).add("c", LongType)), + expectedCaseSensitiveAns = Seq( + Row(Row(3, 4), Row(1, 2)) + ), + expectedCaseInsensitiveSchema = new StructType() + .add("a", ArrayType(new StructType().add("b", LongType).add("c", LongType))), + expectedCaseInsensitiveAns = Seq( + Row(List(Row(1, 2), Row(3, 4))) + ) + ) + } + + /** Primitive scenario: with case sensitivity on, B and b are expected as distinct long columns next to a and c; with it off, only a, b and c are expected. */ private val primitiveTypeCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = { + val caseInSensitivePrimitiveTypes = + """ + | + | 1 + | 2 + | 3 + | + | + | 4 + | 5 + | 6 + | + |""".stripMargin + XmlSchemaInferenceCaseSensitiveTestCase( + "primitive type", + caseInSensitivePrimitiveTypes, + expectedCaseSensitiveSchema = new StructType() + .add("B", LongType) + .add("a", LongType) + .add("b", LongType) + .add("c", LongType), + expectedCaseSensitiveAns = Seq( + Row(null, 1, 2, 3), + Row(5, 4, null, 6) + ), + expectedCaseInsensitiveSchema = new StructType() + .add("a", LongType) + .add("b", LongType) + .add("c", LongType), + expectedCaseInsensitiveAns = Seq( + Row(1, 2, 3), + Row(4, 5, 6) + ) + ) + } + + /** Attribute scenario: with case sensitivity on, _aTtr and _attr are expected as distinct long columns; with it off, a single _attr column is expected. */ private val attributesCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = { + val caseSensitiveAttr = + """ + | + | 1 + | 2 + | 3 + | + | + | 4 + | 5 + | 6 + | + |""".stripMargin + XmlSchemaInferenceCaseSensitiveTestCase( + "attributes", + caseSensitiveAttr, + expectedCaseSensitiveSchema = new StructType() + .add("_aTtr", LongType) + .add("_attr", LongType) + .add("a", LongType) + .add("b", LongType) + .add("c", LongType), + expectedCaseSensitiveAns = Seq( + Row(null, 1, 1, 2, 3), + Row(2, null, 4, 5, 6) + ), + expectedCaseInsensitiveSchema = new StructType() + .add("_attr", LongType) + .add("a", LongType) + .add("b", LongType) + .add("c", LongType), + expectedCaseInsensitiveAns = Seq( + Row(1, 1, 2, 3), + Row(2, 4, 5, 6) + ) + ) + } + + // struct: A 
struct and a + private val structCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = { + val caseSensitiveStruct = + """ + | + | + | 1 + | 3 + | + | + | + | + | 5 + | 7 + | + | + |""".stripMargin + XmlSchemaInferenceCaseSensitiveTestCase( + "struct", + caseSensitiveStruct, + expectedCaseSensitiveSchema = new StructType() + .add( + "A", + new StructType() + .add("a", LongType) + .add("c", LongType) + ) + .add( + "a", + new StructType() + .add("A", LongType) + .add("c", LongType) + ), + expectedCaseSensitiveAns = Seq( + Row(Row(1, 3), null), + Row(null, Row(5, 7)) + ), + expectedCaseInsensitiveSchema = new StructType() + .add( + "A", + new StructType() + .add("a", LongType) + .add("c", LongType) + ), + expectedCaseInsensitiveAns = Seq( + Row(Row(1, 3)), + Row(Row(5, 7)) + ) + ) + } + + /** One scenario: a display name, the XML input, and the expected schema and rows for case-sensitive and for case-insensitive inference. */ case class XmlSchemaInferenceCaseSensitiveTestCase( + name: String, + xmlString: String, + expectedCaseSensitiveSchema: StructType, + expectedCaseSensitiveAns: Seq[Row], + expectedCaseInsensitiveSchema: StructType, + expectedCaseInsensitiveAns: Seq[Row] + ) + + private val testcases: Seq[XmlSchemaInferenceCaseSensitiveTestCase] = Seq( + valueTagCaseSensitivityTestcase, + arrayComplexCaseSensitivityTestcase, + arraySimpleCaseSensitivityTestcase, + primitiveTypeCaseSensitivityTestcase, + structCaseSensitivityTestcase, + attributesCaseSensitivityTestcase + ) + + /* Registers one test per scenario: the XML is written into a temp dir supplied by withTempDir, the dir is read with rowTag set to ROW under CASE_SENSITIVE false and then true, and each read is checked with a schema equality assert plus checkAnswer. The path returned by writeXmlStringToFile is not used; the whole dir is read instead. */ testcases.foreach { testcase => + test(s"case sensitivity test - ${testcase.name}") { + withTempDir { dir => + withSQLConf(customSQLConf.toSeq: _*) { + val baseOptions = Map("rowTag" -> "ROW") + writeXmlStringToFile(testcase.xmlString, dir) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val xml = + spark.read + .options(baseOptions) + .xml(dir.getCanonicalPath) + assert(xml.schema == testcase.expectedCaseInsensitiveSchema) + checkAnswer(xml, testcase.expectedCaseInsensitiveAns) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val xml = spark.read + .options(baseOptions) + 
.xml(dir.getCanonicalPath) + assert(xml.schema == testcase.expectedCaseSensitiveSchema) + checkAnswer(xml, testcase.expectedCaseSensitiveAns) + } + } + } + } + } +} +