Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.xml

import java.io.File
import java.nio.file.Files
import java.util.UUID

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row}
Expand All @@ -35,7 +37,11 @@ import org.apache.spark.sql.types.{
StructType
}

class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXmlData {
class XmlInferSchemaSuite
extends QueryTest
with SharedSparkSession
with TestXmlData
with XmlSchemaInferenceCaseSensitivityTests {

private val baseOptions = Map("rowTag" -> "ROW")

Expand All @@ -50,6 +56,9 @@ class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXml
spark.read.options(baseOptions ++ options).xml(dataset)
}

// This suite supplies no additional SQL confs to the case-sensitivity tests it mixes in.
override protected def customSQLConf
: Map[String, String] = Map.empty

// TODO: add tests for type widening
test("Type conflict in primitive field values") {
val xmlDF = readData(primitiveFieldValueTypeConflict, Map("nullValue" -> ""))
Expand Down Expand Up @@ -630,3 +639,306 @@ class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXml
checkAnswer(xmlDF, expectedAns)
}
}

trait XmlSchemaInferenceCaseSensitivityTests extends QueryTest {

/**
 * Extra SQL confs applied (via `withSQLConf`) around every case-sensitivity test registered by
 * this trait. Empty by default; a suite mixing in this trait may override it.
 */
protected def customSQLConf: Map[String, String] = Map.empty

/**
 * Writes `xmlString` to the file `fileName` under `dir` and returns the path of that file.
 *
 * The content is always encoded as UTF-8 so that the bytes written to disk do not depend on
 * the default charset of the JVM running the test.
 *
 * @param xmlString XML content to write
 * @param dir directory that receives the file
 * @param multiline when false, every character below ' ' (for example line breaks and tabs)
 *                  is removed before writing
 * @param fileName name of the file to write; defaults to a random UUID
 * @return the canonical path of `dir` joined with `fileName` by "/"
 */
private def writeXmlStringToFile(
    xmlString: String,
    dir: File,
    multiline: Boolean = true,
    fileName: String = UUID.randomUUID().toString): String = {
  import java.nio.charset.StandardCharsets

  // Single-line mode drops control characters, which collapses the document onto one line.
  val content = if (multiline) xmlString else xmlString.filter(_ >= ' ')
  Files.write(new File(dir, fileName).toPath, content.getBytes(StandardCharsets.UTF_8))
  s"${dir.getCanonicalPath}/$fileName"
}

// Value-tag scenario: <a> carries text plus a child <b>, while <A> carries text only.
private val valueTagCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
  val xml =
    """
      |<ROW>
      | <a>
      | 1
      | <b>2</b>
      | </a>
      |</ROW>
      |<ROW>
      | <A>
      | 3
      | </A>
      |</ROW>
      |""".stripMargin
  // Struct expected for "a" in both modes: the element text (_VALUE) next to child b.
  val valueAndB = new StructType().add("_VALUE", LongType).add("b", LongType)
  XmlSchemaInferenceCaseSensitiveTestCase(
    name = "value tag",
    xmlString = xml,
    // Case sensitive: "A" and "a" are expected as two separate columns.
    expectedCaseSensitiveSchema = new StructType().add("A", LongType).add("a", valueAndB),
    expectedCaseSensitiveAns = Seq(
      Row(null, Row(1, 2)),
      Row(3, null)
    ),
    // Case insensitive: a single column "a" is expected for both spellings.
    expectedCaseInsensitiveSchema = new StructType().add("a", valueAndB),
    expectedCaseInsensitiveAns = Seq(
      Row(Row(1, 2)),
      Row(Row(3, null))
    )
  )
}

// Array scenario: one row holds both <a> (text plus children b and c) and <A> (text only);
// a second row holds only <A>.
// NOTE(review): this val is named "arrayComplex..." but its test label is
// "array type - simple" -- confirm which of the two is intended.
private val arrayComplexCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
  val xml =
    """
      |<ROW>
      | <a>
      | 1
      | <b>2</b>
      | <c>3</c>
      | </a>
      | <A>
      | 4
      | </A>
      |</ROW>
      |<ROW>
      | <A>5</A>
      |</ROW>
      |""".stripMargin
  // Element text (_VALUE) together with children b and c.
  val valueBAndC = new StructType()
    .add("_VALUE", LongType)
    .add("b", LongType)
    .add("c", LongType)
  XmlSchemaInferenceCaseSensitiveTestCase(
    name = "array type - simple",
    xmlString = xml,
    // Case sensitive: "A" is expected as a plain long, "a" as the struct.
    expectedCaseSensitiveSchema = new StructType().add("A", LongType).add("a", valueBAndC),
    expectedCaseSensitiveAns = Seq(
      Row(4, Row(1, 2, 3)),
      Row(5, null)
    ),
    // Case insensitive: both spellings are expected in one array-of-struct column "a".
    expectedCaseInsensitiveSchema = new StructType().add("a", ArrayType(valueBAndC)),
    expectedCaseInsensitiveAns = Seq(
      Row(List(Row(1, 2, 3), Row(4, null, null))),
      Row(List(Row(5, null, null)))
    )
  )
}

// Array scenario: a single row holds <a> with children b, c and <A> with children B, c.
// NOTE(review): this val is named "arraySimple..." but its test label is
// "array type - complex" -- confirm which of the two is intended.
private val arraySimpleCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
  val xml =
    """
      |<ROW>
      | <a>
      | <b>1</b>
      | <c>2</c>
      | </a>
      | <A>
      | <B>3</B>
      | <c>4</c>
      | </A>
      |</ROW>
      |""".stripMargin
  // Struct with lower-case children, used for "a" in both modes.
  val lowerBAndC = new StructType().add("b", LongType).add("c", LongType)
  val upperBAndC = new StructType().add("B", LongType).add("c", LongType)
  XmlSchemaInferenceCaseSensitiveTestCase(
    name = "array type - complex",
    xmlString = xml,
    // Case sensitive: "A" and "a" are expected as separate struct columns.
    expectedCaseSensitiveSchema = new StructType().add("A", upperBAndC).add("a", lowerBAndC),
    expectedCaseSensitiveAns = Seq(
      Row(Row(3, 4), Row(1, 2))
    ),
    // Case insensitive: both elements are expected in one array-of-struct column "a".
    expectedCaseInsensitiveSchema = new StructType().add("a", ArrayType(lowerBAndC)),
    expectedCaseInsensitiveAns = Seq(
      Row(List(Row(1, 2), Row(3, 4)))
    )
  )
}

// Primitive scenario: the second row spells the middle field as <B> instead of <b>.
private val primitiveTypeCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
  val xml =
    """
      |<ROW>
      | <a>1</a>
      | <b>2</b>
      | <c>3</c>
      |</ROW>
      |<ROW>
      | <a>4</a>
      | <B>5</B>
      | <c>6</c>
      |</ROW>
      |""".stripMargin

  // Builds a struct of LongType columns, in the given order.
  def longColumns(names: String*): StructType =
    names.foldLeft(new StructType()) { (schema, column) => schema.add(column, LongType) }

  XmlSchemaInferenceCaseSensitiveTestCase(
    name = "primitive type",
    xmlString = xml,
    // Case sensitive: "B" and "b" are expected as distinct columns.
    expectedCaseSensitiveSchema = longColumns("B", "a", "b", "c"),
    expectedCaseSensitiveAns = Seq(
      Row(null, 1, 2, 3),
      Row(5, 4, null, 6)
    ),
    // Case insensitive: a single "b" column is expected for both spellings.
    expectedCaseInsensitiveSchema = longColumns("a", "b", "c"),
    expectedCaseInsensitiveAns = Seq(
      Row(1, 2, 3),
      Row(4, 5, 6)
    )
  )
}

// Attribute scenario: the row attribute is spelled "attr" in one row and "aTtr" in the other.
private val attributesCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
  val xml =
    """
      |<ROW attr="1">
      | <a>1</a>
      | <b>2</b>
      | <c>3</c>
      |</ROW>
      |<ROW aTtr="2">
      | <a>4</a>
      | <b>5</b>
      | <c>6</c>
      |</ROW>
      |""".stripMargin

  // Builds a struct of LongType columns, in the given order.
  def longColumns(names: String*): StructType =
    names.foldLeft(new StructType()) { (schema, column) => schema.add(column, LongType) }

  XmlSchemaInferenceCaseSensitiveTestCase(
    name = "attributes",
    xmlString = xml,
    // Case sensitive: "_aTtr" and "_attr" are expected as distinct columns.
    expectedCaseSensitiveSchema = longColumns("_aTtr", "_attr", "a", "b", "c"),
    expectedCaseSensitiveAns = Seq(
      Row(null, 1, 1, 2, 3),
      Row(2, null, 4, 5, 6)
    ),
    // Case insensitive: a single "_attr" column is expected for both spellings.
    expectedCaseInsensitiveSchema = longColumns("_attr", "a", "b", "c"),
    expectedCaseInsensitiveAns = Seq(
      Row(1, 1, 2, 3),
      Row(2, 4, 5, 6)
    )
  )
}

// Struct scenario: the first row has <A> with children a and c; the second row has <a> with
// children A and c.
private val structCaseSensitivityTestcase: XmlSchemaInferenceCaseSensitiveTestCase = {
  val xml =
    """
      |<ROW>
      | <A>
      | <a>1</a>
      | <c>3</c>
      | </A>
      |</ROW>
      |<ROW>
      | <a>
      | <A>5</A>
      | <c>7</c>
      | </a>
      |</ROW>
      |""".stripMargin
  // Children of the upper-case <A> element; also the only struct expected when case is ignored.
  val lowerAAndC = new StructType().add("a", LongType).add("c", LongType)
  // Children of the lower-case <a> element.
  val upperAAndC = new StructType().add("A", LongType).add("c", LongType)
  XmlSchemaInferenceCaseSensitiveTestCase(
    name = "struct",
    xmlString = xml,
    // Case sensitive: "A" and "a" are expected as separate struct columns.
    expectedCaseSensitiveSchema = new StructType().add("A", lowerAAndC).add("a", upperAAndC),
    expectedCaseSensitiveAns = Seq(
      Row(Row(1, 3), null),
      Row(null, Row(5, 7))
    ),
    // Case insensitive: a single struct column "A" is expected for both rows.
    expectedCaseInsensitiveSchema = new StructType().add("A", lowerAAndC),
    expectedCaseInsensitiveAns = Seq(
      Row(Row(1, 3)),
      Row(Row(5, 7))
    )
  )
}

/**
 * One schema-inference scenario that is checked with SQLConf.CASE_SENSITIVE set to "false"
 * and to "true".
 *
 * @param name suffix of the generated test title
 * @param xmlString XML document written to a temporary file and read back with rowTag "ROW"
 * @param expectedCaseSensitiveSchema schema expected when case sensitivity is "true"
 * @param expectedCaseSensitiveAns rows expected when case sensitivity is "true"
 * @param expectedCaseInsensitiveSchema schema expected when case sensitivity is "false"
 * @param expectedCaseInsensitiveAns rows expected when case sensitivity is "false"
 */
case class XmlSchemaInferenceCaseSensitiveTestCase(
name: String,
xmlString: String,
expectedCaseSensitiveSchema: StructType,
expectedCaseSensitiveAns: Seq[Row],
expectedCaseInsensitiveSchema: StructType,
expectedCaseInsensitiveAns: Seq[Row]
)

// All scenarios for which a "case sensitivity test - <name>" test is registered below,
// in registration order.
private val testcases: Seq[XmlSchemaInferenceCaseSensitiveTestCase] = Seq(
valueTagCaseSensitivityTestcase,
arrayComplexCaseSensitivityTestcase,
arraySimpleCaseSensitivityTestcase,
primitiveTypeCaseSensitivityTestcase,
structCaseSensitivityTestcase,
attributesCaseSensitivityTestcase
)

// Registers one test per scenario. Each test writes the scenario's XML into a temporary
// directory and reads it back twice: first with SQLConf.CASE_SENSITIVE set to "false", then
// with "true", comparing the inferred schema and the rows against the expectations.
testcases.foreach { tc =>
  test(s"case sensitivity test - ${tc.name}") {
    withTempDir { tempDir =>
      withSQLConf(customSQLConf.toSeq: _*) {
        val readOptions = Map("rowTag" -> "ROW")
        writeXmlStringToFile(tc.xmlString, tempDir)
        val inputPath = tempDir.getCanonicalPath
        // (value for CASE_SENSITIVE, expected schema, expected rows), in execution order.
        val runs = Seq(
          ("false", tc.expectedCaseInsensitiveSchema, tc.expectedCaseInsensitiveAns),
          ("true", tc.expectedCaseSensitiveSchema, tc.expectedCaseSensitiveAns)
        )
        runs.foreach { case (caseSensitive, expectedSchema, expectedRows) =>
          withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
            val df = spark.read.options(readOptions).xml(inputPath)
            assert(df.schema == expectedSchema)
            checkAnswer(df, expectedRows)
          }
        }
      }
    }
  }
}
}