diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index 11edce8140f09..199f1abd7e20f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -25,7 +25,6 @@ import javax.xml.stream.events._
 import javax.xml.transform.stream.StreamSource
 import javax.xml.validation.Schema
 
-import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.Try
@@ -151,12 +150,7 @@ class StaxXmlParser(
       }
       val parser = StaxXmlParserUtils.filteredReader(xml)
       val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
-      // A structure object is an attribute-only element
-      // if it only consists of attributes and valueTags.
-      val isRootAttributesOnly = schema.fields.forall { f =>
-        f.name == options.valueTag || f.name.startsWith(options.attributePrefix)
-      }
-      val result = Some(convertObject(parser, schema, rootAttributes, isRootAttributesOnly))
+      val result = Some(convertObject(parser, schema, rootAttributes))
       parser.close()
       result
     } catch {
@@ -195,69 +189,60 @@ class StaxXmlParser(
   private[xml] def convertField(
       parser: XMLEventReader,
       dataType: DataType,
+      startElementName: String,
       attributes: Array[Attribute] = Array.empty): Any = {
-    def convertComplicatedType(dt: DataType, attributes: Array[Attribute]): Any = dt match {
+    def convertComplicatedType(
+        dt: DataType,
+        startElementName: String,
+        attributes: Array[Attribute]): Any = dt match {
       case st: StructType => convertObject(parser, st)
       case MapType(StringType, vt, _) => convertMap(parser, vt, attributes)
-      case ArrayType(st, _) => convertField(parser, st)
+      case ArrayType(st, _) => convertField(parser, st, startElementName)
       case _: StringType =>
-        convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType)
+        convertTo(
+          StaxXmlParserUtils.currentStructureAsString(
+            parser, startElementName, options),
+          StringType)
     }
 
     (parser.peek, dataType) match {
-      case (_: StartElement, dt: DataType) => convertComplicatedType(dt, attributes)
+      case (_: StartElement, dt: DataType) =>
+        convertComplicatedType(dt, startElementName, attributes)
       case (_: EndElement, _: StringType) =>
+        StaxXmlParserUtils.skipNextEndElement(parser, startElementName, options)
         // Empty. It's null if "" is the null value
         if (options.nullValue == "") {
           null
         } else {
           UTF8String.fromString("")
         }
-      case (_: EndElement, _: DataType) => null
+      case (_: EndElement, _: DataType) =>
+        StaxXmlParserUtils.skipNextEndElement(parser, startElementName, options)
+        null
       case (c: Characters, ArrayType(st, _)) =>
         // For `ArrayType`, it needs to return the type of element. The values are merged later.
         parser.next
-        convertTo(c.getData, st)
-      case (c: Characters, st: StructType) =>
-        parser.next
-        parser.peek match {
-          case _: EndElement =>
-            // It couldn't be an array of value tags
-            // as the opening tag is immediately followed by a closing tag.
-            if (c.isWhiteSpace) {
-              return null
-            }
-            val indexOpt = getFieldNameToIndex(st).get(options.valueTag)
-            indexOpt match {
-              case Some(index) =>
-                convertTo(c.getData, st.fields(index).dataType)
-              case None => null
-            }
-          case _ =>
-            val row = convertObject(parser, st)
-            if (!c.isWhiteSpace) {
-              addOrUpdate(row.toSeq(st).toArray, st, options.valueTag, c.getData, addToTail = false)
-            } else {
-              row
-            }
-        }
+        val value = convertTo(c.getData, st)
+        StaxXmlParserUtils.skipNextEndElement(parser, startElementName, options)
+        value
+      case (_: Characters, st: StructType) =>
+        convertObject(parser, st)
       case (_: Characters, _: StringType) =>
-        convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType)
+        convertTo(
+          StaxXmlParserUtils.currentStructureAsString(
+            parser, startElementName, options),
+          StringType)
       case (c: Characters, _: DataType) if c.isWhiteSpace =>
         // When `Characters` is found, we need to look further to decide
        // if this is really data or space between other elements.
-        val data = c.getData
         parser.next
-        parser.peek match {
-          case _: StartElement => convertComplicatedType(dataType, attributes)
-          case _: EndElement if data.isEmpty => null
-          case _: EndElement => convertTo(data, dataType)
-          case _ => convertField(parser, dataType, attributes)
-        }
+        convertField(parser, dataType, startElementName, attributes)
       case (c: Characters, dt: DataType) =>
+        val value = convertTo(c.getData, dt)
         parser.next
-        convertTo(c.getData, dt)
+        StaxXmlParserUtils.skipNextEndElement(parser, startElementName, options)
+        value
       case (e: XMLEvent, dt: DataType) =>
         throw new IllegalArgumentException(
           s"Failed to parse a value for data type $dt with event ${e.toString}")
@@ -280,16 +265,16 @@ class StaxXmlParser(
     while (!shouldStop) {
       parser.nextEvent match {
         case e: StartElement =>
+          val key = StaxXmlParserUtils.getName(e.asStartElement.getName, options)
           kvPairs +=
-            (UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, options)) ->
-              convertField(parser, valueType))
+            (UTF8String.fromString(key) -> convertField(parser, valueType, key))
         case c: Characters if !c.isWhiteSpace =>
           // Create a value tag field for it
           kvPairs +=
           // TODO: We don't support an array value tags in map yet.
            (UTF8String.fromString(options.valueTag) -> convertTo(c.getData, valueType))
-        case _: EndElement =>
-          shouldStop = StaxXmlParserUtils.checkEndElement(parser)
+        case _: EndElement | _: EndDocument =>
+          shouldStop = true
         case _ => // do nothing
       }
     }
@@ -321,6 +306,7 @@ class StaxXmlParser(
   private def convertObjectWithAttributes(
       parser: XMLEventReader,
       schema: StructType,
+      startElementName: String,
       attributes: Array[Attribute] = Array.empty): InternalRow = {
     // TODO: This method might have to be removed. Some logics duplicate `convertObject()`
     val row = new Array[Any](schema.length)
@@ -329,7 +315,7 @@ class StaxXmlParser(
     // Read attributes first.
     val attributesMap = convertAttributes(attributes, schema)
 
     // Then, we read elements here.
-    val fieldsMap = convertField(parser, schema) match {
+    val fieldsMap = convertField(parser, schema, startElementName) match {
       case internalRow: InternalRow =>
         Map(schema.map(_.name).zip(internalRow.toSeq(schema)): _*)
       case v if schema.fieldNames.contains(options.valueTag) =>
@@ -363,8 +349,7 @@ class StaxXmlParser(
   private def convertObject(
       parser: XMLEventReader,
       schema: StructType,
-      rootAttributes: Array[Attribute] = Array.empty,
-      isRootAttributesOnly: Boolean = false): InternalRow = {
+      rootAttributes: Array[Attribute] = Array.empty): InternalRow = {
     val row = new Array[Any](schema.length)
     val nameToIndex = getFieldNameToIndex(schema)
     // If there are attributes, then we process them first.
@@ -388,7 +373,7 @@ class StaxXmlParser(
             nameToIndex.get(field) match {
               case Some(index) => schema(index).dataType match {
                 case st: StructType =>
-                  row(index) = convertObjectWithAttributes(parser, st, attributes)
+                  row(index) = convertObjectWithAttributes(parser, st, field, attributes)
 
                 case ArrayType(dt: DataType, _) =>
                   val values = Option(row(index))
@@ -396,21 +381,21 @@ class StaxXmlParser(
                     .map(_.asInstanceOf[ArrayBuffer[Any]])
                     .getOrElse(ArrayBuffer.empty[Any])
                   val newValue = dt match {
                     case st: StructType =>
-                      convertObjectWithAttributes(parser, st, attributes)
+                      convertObjectWithAttributes(parser, st, field, attributes)
                     case dt: DataType =>
-                      convertField(parser, dt)
+                      convertField(parser, dt, field)
                   }
                   row(index) = values :+ newValue
 
                 case dt: DataType =>
-                  row(index) = convertField(parser, dt, attributes)
+                  row(index) = convertField(parser, dt, field, attributes)
               }
 
               case None =>
                 if (hasWildcard) {
                   // Special case: there's an 'any' wildcard element that matches anything else
                   // as a string (or array of strings, to parse multiple ones)
-                  val newValue = convertField(parser, StringType)
+                  val newValue = convertField(parser, StringType, field)
                   val anyIndex = schema.fieldIndex(wildcardColName)
                   schema(wildcardColName).dataType match {
                     case StringType =>
@@ -423,19 +408,21 @@ class StaxXmlParser(
                     }
                 } else {
                   StaxXmlParserUtils.skipChildren(parser)
+                  StaxXmlParserUtils.skipNextEndElement(parser, field, options)
                 }
             }
           } catch {
             case e: SparkUpgradeException => throw e
             case NonFatal(e) =>
+              // TODO: we don't support partial results now
              badRecordException = badRecordException.orElse(Some(e))
           }
 
         case c: Characters if !c.isWhiteSpace =>
           addOrUpdate(row, schema, options.valueTag, c.getData)
 
-        case _: EndElement =>
-          shouldStop = parseAndCheckEndElement(row, schema, parser)
+        case _: EndElement | _: EndDocument =>
+          shouldStop = true
 
         case _ => // do nothing
       }
@@ -599,24 +586,6 @@ class StaxXmlParser(
     }
   }
 
-  @tailrec
-  private def parseAndCheckEndElement(
-      row: Array[Any],
-      schema: StructType,
-      parser: XMLEventReader): Boolean = {
-    parser.peek match {
-      case _: EndElement | _: EndDocument => true
-      case _: StartElement => false
-      case c: Characters if !c.isWhiteSpace =>
-        parser.nextEvent()
-        addOrUpdate(row, schema, options.valueTag, c.getData)
-        parseAndCheckEndElement(row, schema, parser)
-      case _ =>
-        parser.nextEvent()
-        parseAndCheckEndElement(row, schema, parser)
-    }
-  }
-
   private def addOrUpdate(
       row: Array[Any],
       schema: StructType,
@@ -628,17 +597,14 @@ class StaxXmlParser(
       schema(index).dataType match {
         case ArrayType(elementType, _) =>
           val value = convertTo(data, elementType)
-          val result = if (row(index) == null) {
-            ArrayBuffer(value)
-          } else {
-            val genericArrayData = row(index).asInstanceOf[GenericArrayData]
-            if (addToTail) {
-              genericArrayData.toArray(elementType) :+ value
+          val values = Option(row(index))
+            .map(_.asInstanceOf[ArrayBuffer[Any]])
+            .getOrElse(ArrayBuffer.empty[Any])
+          row(index) = if (addToTail) {
+            values :+ value
           } else {
-              value +: genericArrayData.toArray(elementType)
+            value +: values
           }
-          }
-          row(index) = new GenericArrayData(result)
         case dataType =>
           row(index) = convertTo(data, dataType)
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
index 0471cb310d89d..a59ea6f460dee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
@@ -38,9 +38,14 @@ object StaxXmlParserUtils {
   def filteredReader(xml: String): XMLEventReader = {
     val filter = new EventFilter {
       override def accept(event: XMLEvent): Boolean =
-        // Ignore comments and processing instructions
         event.getEventType match {
+          // Ignore comments and processing instructions
           case XMLStreamConstants.COMMENT | XMLStreamConstants.PROCESSING_INSTRUCTION => false
+          // unsupported events
+          case XMLStreamConstants.DTD |
+               XMLStreamConstants.ENTITY_DECLARATION |
+               XMLStreamConstants.ENTITY_REFERENCE |
+               XMLStreamConstants.NOTATION_DECLARATION => false
           case _ => true
         }
     }
@@ -121,7 +126,10 @@ object StaxXmlParserUtils {
   /**
    * Convert the current structure of XML document to a XML string.
    */
-  def currentStructureAsString(parser: XMLEventReader): String = {
+  def currentStructureAsString(
+      parser: XMLEventReader,
+      startElementName: String,
+      options: XmlOptions): String = {
     val xmlString = new StringBuilder()
     var indent = 0
     do {
@@ -151,6 +159,7 @@ object StaxXmlParserUtils {
           indent > 0
         case _ => true
       })
+    skipNextEndElement(parser, startElementName, options)
     xmlString.toString()
   }
 
@@ -178,4 +187,21 @@ object StaxXmlParserUtils {
       }
     }
   }
+
+  @tailrec
+  def skipNextEndElement(
+      parser: XMLEventReader,
+      expectedNextEndElementName: String,
+      options: XmlOptions): Unit = {
+    parser.nextEvent() match {
+      case c: Characters if c.isWhiteSpace =>
+        skipNextEndElement(parser, expectedNextEndElementName, options)
+      case endElement: EndElement =>
+        assert(
+          getName(endElement.getName, options) == expectedNextEndElementName,
+          s"Expected EndElement </$expectedNextEndElementName>")
+      case _ => throw new IllegalStateException(
+        s"Expected EndElement </$expectedNextEndElementName>")
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
index 59222f56454fa..51d5ae532b05d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
@@ -23,7 +23,6 @@ import javax.xml.stream.events._
 import javax.xml.transform.stream.StreamSource
 import javax.xml.validation.Schema
 
-import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.control.Exception._
@@ -157,38 +156,17 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     parser.peek match {
       case _: EndElement => NullType
       case _: StartElement => inferObject(parser)
-      case c: Characters if c.isWhiteSpace =>
-        // When `Characters` is found, we need to look further to decide
-        // if this is really data or space between other elements.
-        val data = c.getData
-        parser.nextEvent()
-        parser.peek match {
-          case _: StartElement => inferObject(parser)
-          case _: EndElement if data.isEmpty => NullType
-          case _: EndElement if options.nullValue == "" => NullType
-          case _: EndElement => StringType
-          case _ => inferField(parser)
-        }
-      case c: Characters if !c.isWhiteSpace =>
-        val characterType = inferFrom(c.getData)
-        parser.nextEvent()
-        parser.peek match {
-          case _: StartElement =>
-            // Some more elements follow;
-            // This is a mix of values and other elements
-            val innerType = inferObject(parser).asInstanceOf[StructType]
-            addOrUpdateValueTagType(innerType, characterType)
-          case _ =>
-            val fieldType = inferField(parser)
-            fieldType match {
-              case st: StructType => addOrUpdateValueTagType(st, characterType)
-              case _: NullType => characterType
-              case _: DataType =>
-                // The field type couldn't be an array type
-                new StructType()
-                  .add(options.valueTag, addOrUpdateType(Some(characterType), fieldType))
-
-            }
+      case _: Characters =>
+        val structType = inferObject(parser).asInstanceOf[StructType]
+        structType match {
+          case _ if structType.fields.isEmpty =>
+            NullType
+          case simpleType
+              if structType.fields.length == 1
+                && isPrimitiveType(structType.fields.head.dataType)
+                && isValueTagField(structType.fields.head, caseSensitive) =>
+            simpleType.fields.head.dataType
+          case _ => structType
         }
       case e: XMLEvent =>
         throw new IllegalArgumentException(s"Failed to parse data with unexpected event $e")
@@ -224,22 +202,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     val nameToDataType =
       collection.mutable.TreeMap.empty[String, DataType](caseSensitivityOrdering)
 
-    @tailrec
-    def inferAndCheckEndElement(parser: XMLEventReader): Boolean = {
-      parser.peek match {
-        case _: EndElement | _: EndDocument => true
-        case _: StartElement => false
-        case c: Characters if !c.isWhiteSpace =>
-          val characterType = inferFrom(c.getData)
-          parser.nextEvent()
-          addOrUpdateType(nameToDataType, options.valueTag, characterType)
-          inferAndCheckEndElement(parser)
-        case _ =>
-          parser.nextEvent()
-          inferAndCheckEndElement(parser)
-      }
-    }
-
     // If there are attributes, then we should process them first.
     val rootValuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
@@ -253,6 +215,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
         case e: StartElement =>
           val attributes = e.getAttributes.asScala.toArray
           val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
+          val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options)
           val inferredType = inferField(parser) match {
             case st: StructType if valuesMap.nonEmpty =>
               // Merge attributes to the field
@@ -267,7 +230,9 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
             case dt: DataType if valuesMap.nonEmpty =>
               // We need to manually add the field for value.
               val nestedBuilder = ArrayBuffer[StructField]()
-              nestedBuilder += StructField(options.valueTag, dt, nullable = true)
+              if (!dt.isInstanceOf[NullType]) {
+                nestedBuilder += StructField(options.valueTag, dt, nullable = true)
+              }
               valuesMap.foreach {
                 case (f, v) =>
                   nestedBuilder += StructField(f, inferFrom(v), nullable = true)
@@ -277,16 +242,15 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
             case dt: DataType => dt
           }
           // Add the field and datatypes so that we can check if this is ArrayType.
-          val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options)
           addOrUpdateType(nameToDataType, field, inferredType)
 
         case c: Characters if !c.isWhiteSpace =>
-          // This can be an attribute-only object
+          // This is a value tag
           val valueTagType = inferFrom(c.getData)
           addOrUpdateType(nameToDataType, options.valueTag, valueTagType)
 
-        case _: EndElement =>
-          shouldStop = inferAndCheckEndElement(parser)
+        case _: EndElement | _: EndDocument =>
+          shouldStop = true
 
         case _ => // do nothing
       }
@@ -429,56 +393,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     case other => Some(other)
   }
 
-  /**
-   * This helper function merges the data type of value tags and inner elements.
-   * It could only be structure data. Consider the following case,
-   *
-   * <a>
-   *   value1
-   *   <b>1</b>
-   *   value2
-   * </a>
-   *
-   * Input: ''a struct<b int>'' and ''_VALUE string''
-   * Return: ''a struct<b int, _VALUE array<string>>''
-   * @param objectType inner elements' type
-   * @param valueTagType value tag's type
-   */
-  private[xml] def addOrUpdateValueTagType(
-      objectType: DataType,
-      valueTagType: DataType): DataType = {
-    (objectType, valueTagType) match {
-      case (st: StructType, _) =>
-        val valueTagIndexOpt = st.getFieldIndex(options.valueTag)
-
-        valueTagIndexOpt match {
-          // If the field name exists in the inner elements,
-          // merge the type and infer the combined field as an array type if necessary
-          case Some(index) if !st(index).dataType.isInstanceOf[ArrayType] =>
-            updateStructField(
-              st,
-              index,
-              ArrayType(compatibleType(caseSensitive, options.valueTag)(
-                st(index).dataType, valueTagType)))
-          case Some(index) =>
-            updateStructField(st, index, compatibleType(caseSensitive, options.valueTag)(
-              st(index).dataType, valueTagType))
-          case None =>
-            st.add(options.valueTag, valueTagType)
-        }
-      case _ =>
-        throw new IllegalStateException(
-          "illegal state when merging value tags types in schema inference"
-        )
-    }
-  }
-
-  private def updateStructField(
-      structType: StructType,
-      index: Int,
-      newType: DataType): StructType = {
-    val newFields: Array[StructField] =
-      structType.fields.updated(index, structType.fields(index).copy(dataType = newType))
-    StructType(newFields)
-  }
 
   private def addOrUpdateType(
       nameToDataType: collection.mutable.TreeMap[String, DataType],
@@ -501,6 +415,23 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       newType
     }
   }
+
+  private[xml] def isPrimitiveType(dataType: DataType): Boolean = {
+    dataType match {
+      case _: StructType => false
+      case _: ArrayType => false
+      case _: MapType => false
+      case _ => true
+    }
+  }
+
+  private[xml] def isValueTagField(structField: StructField, caseSensitive: Boolean): Boolean = {
+    if (!caseSensitive) {
+      structField.name.toLowerCase(Locale.ROOT) == options.valueTag.toLowerCase(Locale.ROOT)
+    } else {
+      structField.name == options.valueTag
+    }
+  }
 }
 
 object XmlInferSchema {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 78f9d5285c239..5fdf949a2137d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -1098,9 +1098,11 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     assert(valid.toSeq.toArray.take(schema.length - 1) ===
       Array(Row(10, 10), Row(10, "Ten"), 10.0, 10.0, true, "Ten", Array(1, 2),
         Map("a" -> 123, "b" -> 345)))
-    assert(invalid.toSeq.toArray.take(schema.length - 1) ===
-      Array(null, null, null, null, null,
-        "Ten", Array(2), null))
+    // TODO: we don't support partial results
+    assert(
+      invalid.toSeq.toArray.take(schema.length - 1) ===
+        Array(null, null, null, null, null,
+          null, null, null))
 
     assert(valid.toSeq.toArray.last === null)
     assert(invalid.toSeq.toArray.last.toString.contains(
@@ -1337,7 +1339,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       .xml(getTestResourcePath(resDir + "whitespace_error.xml"))
 
     assert(whitespaceDF.count() === 1)
-    assert(whitespaceDF.take(1).head.getAs[String]("_corrupt_record") !== null)
+    assert(whitespaceDF.take(1).head.getAs[String]("_corrupt_record") === null)
   }
 
   test("struct with only attributes and no value tag does not crash") {
@@ -2479,7 +2481,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       .xml(input)
 
     checkAnswer(df, Seq(
-      Row("\" \"", Row(1, "\" \""), Row(Row(null, " ")))))
+      Row("\" \"", Row("\" \"", 1), Row(Row(" ")))))
   }
 
   test("capture values interspersed between elements - nested comments") {
@@ -2552,7 +2554,9 @@
         | value4
         |
         | value5
-        | 1
+        | 1
+        | <![CDATA[This is a CDATA section containing text.]]>
+        | <![CDATA[This is a CDATA section containing text.]]>
         | value6
         | 2
         | value7
@@ -2563,10 +2567,10 @@
         |
         | value10
         |
-        |
+        |
         | 3
         | value11
-        | 4
+        | 4
         |
         | string
         | value12
@@ -2577,7 +2581,9 @@
         |
        | value15
         |
+        |
         | value16
+        |
         |
         |""".stripMargin
     val input = spark.createDataset(Seq(xmlString))
@@ -2594,14 +2600,22 @@ class XmlSuite extends QueryTest with SharedSparkSession {
         Row(
           ArraySeq("value3", "value10", "value13", "value14"),
           Array(
-            Row(
-              ArraySeq("value4", "value8", "value9"),
-              "string",
-              Row(ArraySeq("value5", "value6", "value7"), ArraySeq(1, 2))),
-            Row(
-              ArraySeq("value12"),
-              "string",
-              Row(ArraySeq("value11"), ArraySeq(3, 4)))),
+            Row(
+              ArraySeq("value4", "value8", "value9"),
+              "string",
+              Row(
+                ArraySeq(
+                  "value5",
+                  "This is a CDATA section containing text." +
+                    "\n This is a CDATA section containing text.\n" +
+                    " value6",
+                  "value7"
+                ),
+                ArraySeq(1, 2)
+              )
+            ),
+            Row(ArraySeq("value12"), "string", Row(ArraySeq("value11"), ArraySeq(3, 4)))
+          ),
           3))))
 
     checkAnswer(df, expectedAns)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
index 13a90acb71526..a4ac25b036c41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
@@ -63,7 +63,7 @@ final class StaxXmlParserUtilsSuite extends SparkFunSuite with BeforeAndAfterAll
     val parser = factory.createXMLEventReader(new StringReader(input.toString))
     // Skip until
     StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.END_ELEMENT)
-    val xmlString = StaxXmlParserUtils.currentStructureAsString(parser)
+    val xmlString = StaxXmlParserUtils.currentStructureAsString(parser, "ROW", new XmlOptions())
     val expected = Sam Mad Dog Smith19
     assert(xmlString === expected.toString())
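Reviewer note: the sketch below is a hypothetical spark-shell repro of the behavior this patch pins down; it is not part of the change. The `<ROW>` payload, the column names, and the expected output shape are illustrative assumptions, while the `rowTag` option, the `Dataset[String]` overload of `spark.read.xml`, and the default `_VALUE` value-tag column are taken from the test changes above.

```scala
// Hypothetical repro (assumes a Spark build with this patch applied).
// Character data interspersed between child elements should be collected,
// in document order, under the default value-tag column `_VALUE`.
import spark.implicits._

val xml = Seq(
  """<ROW>
    |  before
    |  <b>1</b>
    |  after
    |</ROW>""".stripMargin).toDS()

val df = spark.read
  .option("rowTag", "ROW") // same reader entry point the updated XmlSuite uses
  .xml(xml)

df.printSchema()
// Illustrative expectation: `_VALUE` inferred as array<string> holding
// "before" and "after", alongside a numeric column for <b>1</b>.
df.show(truncate = false)
```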