diff --git a/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala b/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala index ca34eb49..9560a336 100644 --- a/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala +++ b/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala @@ -15,7 +15,7 @@ */ package com.databricks.spark.xml -import java.io.{InputStream, IOException} +import java.io.{IOException, InputStream} import java.nio.charset.Charset import org.apache.hadoop.conf.Configuration @@ -25,14 +25,17 @@ import org.apache.hadoop.io.{DataOutputBuffer, LongWritable, Text} import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{FileSplit, TextInputFormat} +import scala.collection.mutable.ArrayBuffer + /** * Reads records that are delimited by a specific start/end tag. */ class XmlInputFormat extends TextInputFormat { override def createRecordReader( - split: InputSplit, - context: TaskAttemptContext): RecordReader[LongWritable, Text] = { + split: InputSplit, + context: TaskAttemptContext): + RecordReader[LongWritable, Text] = { new XmlRecordReader } } @@ -47,9 +50,9 @@ object XmlInputFormat { } /** - * XMLRecordReader class to read through a given xml document to output xml blocks as records - * as specified by the start tag and end tag - */ + * XMLRecordReader class to read through a given xml document to output xml blocks as records + * as specified by the start tag and end tag + */ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { private var startTag: Array[Byte] = _ private var currentStartTag: Array[Byte] = _ @@ -111,7 +114,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { // So we have a split that is only part of a file stored using // a Compression codec that cannot be split. 
throw new IOException("Cannot seek in " + - codec.getClass.getSimpleName + " compressed stream") + codec.getClass.getSimpleName + " compressed stream") } val cIn = c.createInputStream(fsin, decompressor) in = cIn @@ -131,13 +134,13 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { } /** - * Finds the start of the next record. - * It treats data from `startTag` and `endTag` as a record. - * - * @param key the current key that will be written - * @param value the object that will be written - * @return whether it reads successfully - */ + * Finds the start of the next record. + * It treats data from `startTag` and `endTag` as a record. + * + * @param key the current key that will be written + * @param value the object that will be written + * @return whether it reads successfully + */ private def next(key: LongWritable, value: Text): Boolean = { if (readUntilStartElement()) { try { @@ -189,9 +192,24 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { false } - private def checkEmptyTag(currentLetter: Int, position: Int): Boolean = { + private def checkEmptyTag(currentLetter: Int, position: Int, + buffer: DataOutputBuffer): Boolean = { + def checkStartTagBefore: Boolean = { + val startAngleInByte = '<'.toByte + val spaceInByte = ' '.toByte + val rootTagName = buffer.getData.take(buffer.getLength) + .reverse + .takeWhile(_ != startAngleInByte) + .reverse + .takeWhile(_ != spaceInByte) + val result = startAngleInByte +: rootTagName + + result.sameElements(startTag.dropRight(1)) + } + if (position >= endEmptyTag.length) false - else currentLetter == endEmptyTag(position) + else currentLetter == endEmptyTag(position) && + checkStartTagBefore } private def readUntilEndElement(): Boolean = { @@ -207,7 +225,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { } else { buffer.write(rb) val b = rb.toByte - if (b == startTag(si) && (b == endTag(ei) || checkEmptyTag(b, ei))) { + if (b == startTag(si) && (b == endTag(ei) 
|| checkEmptyTag(b, ei, buffer))) { // In start tag or end tag. si += 1 ei += 1 @@ -222,9 +240,9 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { si += 1 ei = 0 } - } else if (b == endTag(ei) || checkEmptyTag(b, ei)) { + } else if (b == endTag(ei) || checkEmptyTag(b, ei, buffer)) { if ((b == endTag(ei) && ei >= endTag.length - 1) || - (checkEmptyTag(b, ei) && ei >= endEmptyTag.length - 1)) { + (checkEmptyTag(b, ei, buffer) && ei >= endEmptyTag.length - 1)) { if (depth == 0) { // Found closing end tag. return true @@ -253,7 +271,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { private def checkAttributes(current: Int): Boolean = { var len = 0 var b = current - while(len < space.length && b == space(len)) { + while (len < space.length && b == space(len)) { len += 1 if (len >= space.length) { currentStartTag = startTag.take(startTag.length - angleBracket.length) ++ space diff --git a/src/test/resources/self-closing-tag.xml b/src/test/resources/self-closing-tag.xml new file mode 100644 index 00000000..c3057b22 --- /dev/null +++ b/src/test/resources/self-closing-tag.xml @@ -0,0 +1,6 @@ +<ROWSET> + <ROW> + <non-empty-tag>1</non-empty-tag> + <self-closing-tag/> + </ROW> +</ROWSET> diff --git a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala index 3dd56d3e..62262597 100755 --- a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala +++ b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala @@ -60,6 +60,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { val nestedElementWithNameOfParent = "src/test/resources/nested-element-with-name-of-parent.xml" val booksMalformedAttributes = "src/test/resources/books-malformed-attributes.xml" val fiasHouse = "src/test/resources/fias_house.xml" + val fiasHouseSimple = "src/test/resources/fias_house_simple.xml" + val selfClosingTag = "src/test/resources/self-closing-tag.xml" val booksTag = "book" val booksRootTag = "books" @@ -904,7 +906,7 @@ class XmlSuite extends FunSuite 
with BeforeAndAfterAll { assert(results(1)(0) === "bk112") } - test("read utf-8 encoded file with empty tag") { + test("empty tag data only in attributes") { val df = spark.read.format("xml") .option("excludeAttribute", "false") .option("rowTag", fiasRowTag) @@ -913,4 +915,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { assert(df.collect().length == numFiasHouses) assert(df.select().where("_HOUSEID is null").count() == 0) } + + test("Produces correct result for a row with a self closing tag inside") { + val schema = StructType(Seq( + StructField("non-empty-tag", IntegerType, nullable = true), + StructField("self-closing-tag", IntegerType, nullable = true) + )) + + val result = new XmlReader() + .withSchema(schema) + .xmlFile(spark, selfClosingTag) + .collect() + + assert(result(0).toSeq === Seq(1, null)) + } }