diff --git a/.travis.yml b/.travis.yml index ba5b5d44..ba009c73 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,37 +5,13 @@ cache: - $HOME/.ivy2 matrix: include: - # Spark 1.3.0 - - jdk: openjdk6 - scala: 2.10.5 - env: TEST_HADOOP_VERSION="1.2.1" TEST_SPARK_VERSION="1.3.0" - - jdk: openjdk6 - scala: 2.11.7 - env: TEST_HADOOP_VERSION="1.0.4" TEST_SPARK_VERSION="1.3.0" - # Spark 1.4.1 - # We only test Spark 1.4.1 with Hadooop 2.2.0 because - # https://github.com/apache/spark/pull/6599 is not present in 1.4.1, - # so the published Spark Maven artifacts will not work with Hadoop 1.x. - - jdk: openjdk6 - scala: 2.10.5 - env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="1.4.1" - - jdk: openjdk7 - scala: 2.11.7 - env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="1.4.1" - # Spark 1.5.0 - - jdk: openjdk7 - scala: 2.10.5 - env: TEST_HADOOP_VERSION="1.0.4" TEST_SPARK_VERSION="1.5.0" - - jdk: openjdk7 - scala: 2.11.7 - env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="1.5.0" - # Spark 1.6.0 + # Spark 2.0.0 - jdk: openjdk7 scala: 2.10.5 - env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="1.6.0" + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" - jdk: openjdk7 scala: 2.11.7 - env: TEST_HADOOP_VERSION="1.2.1" TEST_SPARK_VERSION="1.6.0" + env: TEST_HADOOP_VERSION="2.6.0" TEST_SPARK_VERSION="2.0.0" script: - sbt -Dhadoop.testVersion=$TEST_HADOOP_VERSION -Dspark.testVersion=$TEST_SPARK_VERSION ++$TRAVIS_SCALA_VERSION coverage test - sbt ++$TRAVIS_SCALA_VERSION assembly diff --git a/README.md b/README.md index c53c654d..2ec1d6aa 100644 --- a/README.md +++ b/README.md @@ -496,4 +496,3 @@ This library is built with [SBT](http://www.scala-sbt.org/0.13/docs/Command-Line ## Acknowledgements This project was initially created by [HyukjinKwon](https://github.com/HyukjinKwon) and donated to [Databricks](https://databricks.com). 
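The trimmed matrix above now tests only Spark 2.0.0, against Hadoop 2.2.0 and 2.6.0, and hands the versions to sbt through -Dhadoop.testVersion and -Dspark.testVersion. Below is a minimal sketch of how build.sbt can pick those properties up; the testSparkVersion key matches the one visible in the build.sbt hunk further down, while testHadoopVersion and the exact test dependencies are illustrative assumptions rather than the project's verbatim build.

val testSparkVersion = settingKey[String]("The version of Spark to test against.")
testSparkVersion := sys.props.getOrElse("spark.testVersion", sparkVersion.value)

val testHadoopVersion = settingKey[String]("The version of Hadoop to test against.")
testHadoopVersion := sys.props.getOrElse("hadoop.testVersion", "2.2.0")

// Pin the test classpath to the requested versions so each Travis job really
// runs against its TEST_SPARK_VERSION / TEST_HADOOP_VERSION combination.
libraryDependencies ++= Seq(
  "org.apache.hadoop" % "hadoop-client" % testHadoopVersion.value % "test",
  "org.apache.spark" %% "spark-core" % testSparkVersion.value % "test",
  "org.apache.spark" %% "spark-sql" % testSparkVersion.value % "test"
)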
- diff --git a/build.sbt b/build.sbt index ca008e3f..5a8c5e51 100755 --- a/build.sbt +++ b/build.sbt @@ -1,6 +1,6 @@ name := "spark-xml" -version := "0.3.4" +version := "0.4.0-SNAPSHOT" organization := "com.databricks" @@ -10,7 +10,7 @@ spName := "databricks/spark-xml" crossScalaVersions := Seq("2.10.5", "2.11.7") -sparkVersion := "1.6.0" +sparkVersion := "2.0.0" val testSparkVersion = settingKey[String]("The version of Spark to test against.") @@ -75,26 +75,3 @@ ScoverageSbtPlugin.ScoverageKeys.coverageHighlighting := { if (scalaBinaryVersion.value == "2.10") false else true } - -// -- MiMa binary compatibility checks ------------------------------------------------------------ - -//import com.typesafe.tools.mima.core._ -//import com.typesafe.tools.mima.plugin.MimaKeys.binaryIssueFilters -//import com.typesafe.tools.mima.plugin.MimaKeys.previousArtifact -//import com.typesafe.tools.mima.plugin.MimaPlugin.mimaDefaultSettings -// -//mimaDefaultSettings ++ Seq( -// previousArtifact := Some("org.apache" %% "spark-xml" % "1.2.0"), -// binaryIssueFilters ++= Seq( -// // These classes are not intended to be public interfaces: -// ProblemFilters.excludePackage("org.apache.spark.xml.XmlRelation"), -// ProblemFilters.excludePackage("org.apache.spark.xml.util.InferSchema"), -// ProblemFilters.excludePackage("org.apache.spark.sql.readers"), -// ProblemFilters.excludePackage("org.apache.spark.xml.util.TypeCast"), -// // We allowed the private `XmlRelation` type to leak into the public method signature: -// ProblemFilters.exclude[IncompatibleResultTypeProblem]( -// "org.apache.spark.xml.DefaultSource.createRelation") -// ) -//) - -// ------------------------------------------------------------------------------------------------ diff --git a/src/main/scala/com/databricks/spark/xml/DefaultSource.scala b/src/main/scala/com/databricks/spark/xml/DefaultSource.scala index 796ca2dc..49fe25c0 100755 --- a/src/main/scala/com/databricks/spark/xml/DefaultSource.scala +++ b/src/main/scala/com/databricks/spark/xml/DefaultSource.scala @@ -20,7 +20,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} -import com.databricks.spark.xml.util.CompressionCodecs import com.databricks.spark.xml.util.XmlFile /** @@ -90,9 +89,7 @@ class DefaultSource } if (doSave) { // Only save data when the save mode is not ignore. - val codecClass = - CompressionCodecs.getCodecClass(XmlOptions(parameters).codec) - data.saveAsXmlFile(filesystemPath.toString, parameters, codecClass) + XmlFile.saveAsXmlFile(data, filesystemPath.toString, parameters) } createRelation(sqlContext, parameters, data.schema) } diff --git a/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala b/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala index c124b11e..66bd7272 100644 --- a/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala +++ b/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala @@ -70,12 +70,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = { val fileSplit: FileSplit = split.asInstanceOf[FileSplit] - val conf: Configuration = { - // Use reflection to get the Configuration. This is necessary because TaskAttemptContext is - // a class in Hadoop 1.x and an interface in Hadoop 2.x. 
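Stepping back to the DefaultSource hunk above: with the codec lookup moved into XmlFile.saveAsXmlFile, the write path of createRelation only has to interpret the SaveMode before delegating. The following is a simplified reconstruction of that decision, not the verbatim project code; the helper name and error messages are assumptions.

import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.SaveMode

// Returns true when data should be written; Overwrite clears any existing output.
def shouldSave(mode: SaveMode, fs: FileSystem, path: Path): Boolean = mode match {
  case SaveMode.Append =>
    sys.error("Append mode is not supported by com.databricks.spark.xml.")
  case SaveMode.Overwrite =>
    fs.delete(path, true)
    true
  case SaveMode.ErrorIfExists =>
    if (fs.exists(path)) sys.error(s"Path $path already exists.") else true
  case SaveMode.Ignore =>
    !fs.exists(path)
}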
- val method = context.getClass.getMethod("getConfiguration") - method.invoke(context).asInstanceOf[Configuration] - } + val conf: Configuration = context.getConfiguration val charset = Charset.forName(conf.get(XmlInputFormat.ENCODING_KEY, XmlOptions.DEFAULT_CHARSET)) startTag = conf.get(XmlInputFormat.START_TAG_KEY).getBytes(charset) @@ -97,25 +92,18 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { val codec = new CompressionCodecFactory(conf).getCodec(path) if (null != codec) { decompressor = CodecPool.getDecompressor(codec) - // Use reflection to get the splittable compression codec and stream. This is necessary - // because SplittableCompressionCodec does not exist in Hadoop 1.0.x. - def isSplitCompressionCodec(obj: Any) = { - val splittableClassName = "org.apache.hadoop.io.compress.SplittableCompressionCodec" - obj.getClass.getInterfaces.map(_.getName).contains(splittableClassName) - } - // Here I made separate variables to avoid to try to find SplitCompressionInputStream at - // runtime. - val (inputStream, seekable) = codec match { - case c: CompressionCodec if isSplitCompressionCodec(c) => - // At Hadoop 1.0.x, this case would not be executed. - val cIn = { - val sc = c.asInstanceOf[SplittableCompressionCodec] - sc.createInputStream(fsin, decompressor, start, - end, SplittableCompressionCodec.READ_MODE.BYBLOCK) - } + codec match { + case sc: SplittableCompressionCodec => + val cIn = sc.createInputStream( + fsin, + decompressor, + start, + end, + SplittableCompressionCodec.READ_MODE.BYBLOCK) start = cIn.getAdjustedStart end = cIn.getAdjustedEnd - (cIn, cIn) + in = cIn + filePosition = cIn case c: CompressionCodec => if (start != 0) { // So we have a split that is only part of a file stored using @@ -124,10 +112,9 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { codec.getClass.getSimpleName + " compressed stream") } val cIn = c.createInputStream(fsin, decompressor) - (cIn, fsin) + in = cIn + filePosition = fsin } - in = inputStream - filePosition = seekable } else { in = fsin filePosition = fsin diff --git a/src/main/scala/com/databricks/spark/xml/XmlRelation.scala b/src/main/scala/com/databricks/spark/xml/XmlRelation.scala index 129f5b38..f72e5150 100755 --- a/src/main/scala/com/databricks/spark/xml/XmlRelation.scala +++ b/src/main/scala/com/databricks/spark/xml/XmlRelation.scala @@ -24,7 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.sources.{PrunedScan, InsertableRelation, BaseRelation, TableScan} import org.apache.spark.sql.types._ -import com.databricks.spark.xml.util.{CompressionCodecs, InferSchema} +import com.databricks.spark.xml.util.{InferSchema, XmlFile} import com.databricks.spark.xml.parsers.StaxXmlParser case class XmlRelation protected[spark] ( @@ -90,8 +90,7 @@ case class XmlRelation protected[spark] ( + s" to INSERT OVERWRITE a XML table:\n${e.toString}") } // Write the data. We assume that schema isn't changed, and we won't update it. 
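The isSplitCompressionCodec reflection removed above only existed because SplittableCompressionCodec is absent from Hadoop 1.0.x; with Hadoop 2.x guaranteed by Spark 2.0, a plain type match is enough, as the replacement code shows. A self-contained illustration of the same distinction follows; the helper name and file names are illustrative.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec}

// True only when the file's codec supports splitting (e.g. bzip2). Gzip yields
// false, since its splits must be read from offset 0, and uncompressed files
// have no codec at all.
def isSplittableCompressed(path: Path, conf: Configuration): Boolean =
  Option(new CompressionCodecFactory(conf).getCodec(path)).exists {
    case _: SplittableCompressionCodec => true
    case _ => false
  }

// isSplittableCompressed(new Path("cars.xml.bz2"), new Configuration())  // true
// isSplittableCompressed(new Path("cars.xml.gz"), new Configuration())   // false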
- val codecClass = CompressionCodecs.getCodecClass(options.codec) - data.saveAsXmlFile(filesystemPath.toString, parameters, codecClass) + XmlFile.saveAsXmlFile(data, filesystemPath.toString, parameters) } else { sys.error("XML tables only support INSERT OVERWRITE for now.") } diff --git a/src/main/scala/com/databricks/spark/xml/package.scala b/src/main/scala/com/databricks/spark/xml/package.scala index 5eae2afd..b0b10994 100755 --- a/src/main/scala/com/databricks/spark/xml/package.scala +++ b/src/main/scala/com/databricks/spark/xml/package.scala @@ -15,23 +15,19 @@ */ package com.databricks.spark -import java.io.CharArrayWriter -import javax.xml.stream.XMLOutputFactory - import scala.collection.Map -import com.sun.xml.internal.txw2.output.IndentingXMLStreamWriter import org.apache.hadoop.io.compress.CompressionCodec import org.apache.spark.sql.{DataFrame, SQLContext} import com.databricks.spark.xml.util.XmlFile -import com.databricks.spark.xml.parsers.StaxXmlGenerator package object xml { /** * Adds a method, `xmlFile`, to [[SQLContext]] that allows reading XML data. */ implicit class XmlContext(sqlContext: SQLContext) extends Serializable { + @deprecated("Use DataFrameReader.read()", "0.4.0") def xmlFile( filePath: String, rowTag: String = XmlOptions.DEFAULT_ROW_TAG, @@ -82,69 +78,16 @@ package object xml { // // // Namely, roundtrip in writing and reading can end up in different schema structure. + @deprecated("Use DataFrameWriter.write()", "0.4.0") def saveAsXmlFile( path: String, parameters: Map[String, String] = Map(), compressionCodec: Class[_ <: CompressionCodec] = null): Unit = { - val options = XmlOptions(parameters.toMap) - val startElement = s"<${options.rootTag}>" - val endElement = s"" - val rowSchema = dataFrame.schema - val indent = XmlFile.DEFAULT_INDENT - val rowSeparator = XmlFile.DEFAULT_ROW_SEPARATOR - - val xmlRDD = dataFrame.rdd.mapPartitions { iter => - val factory = XMLOutputFactory.newInstance() - val writer = new CharArrayWriter() - val xmlWriter = factory.createXMLStreamWriter(writer) - val indentingXmlWriter = new IndentingXMLStreamWriter(xmlWriter) - indentingXmlWriter.setIndentStep(indent) - - new Iterator[String] { - var firstRow: Boolean = true - var lastRow: Boolean = true - - override def hasNext: Boolean = iter.hasNext || firstRow || lastRow - - override def next: String = { - if (iter.nonEmpty) { - val xml = { - StaxXmlGenerator( - rowSchema, - indentingXmlWriter, - options)(iter.next()) - writer.toString - } - writer.reset() - - // Here it needs to add indentations for the start of each line, - // in order to insert the start element and end element. - val indentedXml = indent + xml.replaceAll(rowSeparator, rowSeparator + indent) - if (firstRow) { - firstRow = false - startElement + rowSeparator + indentedXml - } else { - indentedXml - } - } else { - indentingXmlWriter.close() - if (!firstRow) { - lastRow = false - endElement - } else { - // This means the iterator was initially empty. 
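The @deprecated annotations above steer users from the implicit xmlFile and saveAsXmlFile helpers toward the standard reader and writer API that the updated tests below use. A read-side migration sketch; the file path, option values, and a sqlContext in scope are assumptions.

// 0.3.x style, still compiles but now deprecated:
//   val books = sqlContext.xmlFile("books.xml", rowTag = "book", samplingRatio = 0.5)

// 0.4.x style: the former named arguments become plain string options.
val books = sqlContext.read
  .format("xml")                                // or "com.databricks.spark.xml"
  .option("rowTag", "book")
  .option("samplingRatio", "0.5")
  .option("treatEmptyValuesAsNulls", "true")
  .load("books.xml")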
- firstRow = false - lastRow = false - "" - } - } - } - } - } - - compressionCodec match { - case null => xmlRDD.saveAsTextFile(path) - case codec => xmlRDD.saveAsTextFile(path, codec) - } + val mutableParams = collection.mutable.Map(parameters.toSeq: _*) + val safeCodec = mutableParams.get("codec") + .orElse(Option(compressionCodec).map(_.getCanonicalName)) + .orNull + mutableParams.put("codec", safeCodec) + XmlFile.saveAsXmlFile(dataFrame, path, mutableParams.toMap) } } } diff --git a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala index 8d14ff94..cfe08024 100644 --- a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala +++ b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala @@ -85,7 +85,6 @@ private[xml] object StaxXmlGenerator { case (ByteType, v: Byte) => writer.writeCharacters(v.toString) case (BooleanType, v: Boolean) => writer.writeCharacters(v.toString) case (DateType, v) => writer.writeCharacters(v.toString) - case (udt: UserDefinedType[_], v) => writeElement(udt.sqlType, udt.serialize(v)) // For the case roundtrip in reading and writing XML files, [[ArrayType]] cannot have // [[ArrayType]] as element type. It always wraps the element with [[StructType]]. So, diff --git a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala index 404ac0fd..88544977 100644 --- a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala +++ b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala @@ -106,7 +106,6 @@ private[xml] object StaxXmlParser { case dt: StructType => convertObject(parser, dt, options) case MapType(StringType, vt, _) => convertMap(parser, vt, options) case ArrayType(st, _) => convertField(parser, st, options) - case udt: UserDefinedType[_] => convertField(parser, udt.sqlType, options) case _: StringType => StaxXmlParserUtils.currentStructureAsString(parser) } @@ -153,7 +152,7 @@ private[xml] object StaxXmlParser { case (v, ByteType) => castTo(v, ByteType) case (v, ShortType) => castTo(v, ShortType) case (v, IntegerType) => signSafeToInt(v) - case (v, _: DecimalType) => castTo(v, new DecimalType(None)) + case (v, dt: DecimalType) => castTo(v, dt) case (_, dataType) => sys.error(s"Failed to parse a value for data type $dataType.") } diff --git a/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala b/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala index ad4471b6..a3649a37 100644 --- a/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala +++ b/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala @@ -36,7 +36,7 @@ private[xml] object InferSchema { /** * Copied from internal Spark api - * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion]] */ private val numericPrecedence: IndexedSeq[DataType] = IndexedSeq[DataType]( @@ -47,7 +47,7 @@ private[xml] object InferSchema { FloatType, DoubleType, TimestampType, - DecimalType.Unlimited) + DecimalType.SYSTEM_DEFAULT) val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) diff --git a/src/main/scala/com/databricks/spark/xml/util/XmlFile.scala b/src/main/scala/com/databricks/spark/xml/util/XmlFile.scala index 5abb08b2..155f7edd 100644 --- a/src/main/scala/com/databricks/spark/xml/util/XmlFile.scala +++ 
b/src/main/scala/com/databricks/spark/xml/util/XmlFile.scala @@ -15,12 +15,19 @@ */ package com.databricks.spark.xml.util +import java.io.CharArrayWriter import java.nio.charset.Charset +import javax.xml.stream.XMLOutputFactory +import scala.collection.Map + +import com.databricks.spark.xml.parsers.StaxXmlGenerator +import com.sun.xml.internal.txw2.output.IndentingXMLStreamWriter import org.apache.hadoop.io.{Text, LongWritable} -import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.DataFrame import com.databricks.spark.xml.{XmlOptions, XmlInputFormat} private[xml] object XmlFile { @@ -47,4 +54,90 @@ private[xml] object XmlFile { classOf[Text]).map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)) } } + + /** + * Note that writing a XML file from [[DataFrame]] having a field + * [[org.apache.spark.sql.types.ArrayType]] with its element as nested array would have + * an additional nested field for the element. For example, the [[DataFrame]] having + * a field below, + * + * fieldA Array(Array(data1, data2)) + * + * would produce a XML file below. + * + * + * data1 + * + * + * data2 + * + * + * Namely, roundtrip in writing and reading can end up in different schema structure. + */ + def saveAsXmlFile( + dataFrame: DataFrame, + path: String, + parameters: Map[String, String] = Map()): Unit = { + val options = XmlOptions(parameters.toMap) + val codecClass = CompressionCodecs.getCodecClass(options.codec) + val startElement = s"<${options.rootTag}>" + val endElement = s"" + val rowSchema = dataFrame.schema + val indent = XmlFile.DEFAULT_INDENT + val rowSeparator = XmlFile.DEFAULT_ROW_SEPARATOR + + val xmlRDD = dataFrame.rdd.mapPartitions { iter => + val factory = XMLOutputFactory.newInstance() + val writer = new CharArrayWriter() + val xmlWriter = factory.createXMLStreamWriter(writer) + val indentingXmlWriter = new IndentingXMLStreamWriter(xmlWriter) + indentingXmlWriter.setIndentStep(indent) + + new Iterator[String] { + var firstRow: Boolean = true + var lastRow: Boolean = true + + override def hasNext: Boolean = iter.hasNext || firstRow || lastRow + + override def next: String = { + if (iter.nonEmpty) { + val xml = { + StaxXmlGenerator( + rowSchema, + indentingXmlWriter, + options)(iter.next()) + writer.toString + } + writer.reset() + + // Here it needs to add indentations for the start of each line, + // in order to insert the start element and end element. + val indentedXml = indent + xml.replaceAll(rowSeparator, rowSeparator + indent) + if (firstRow) { + firstRow = false + startElement + rowSeparator + indentedXml + } else { + indentedXml + } + } else { + indentingXmlWriter.close() + if (!firstRow) { + lastRow = false + endElement + } else { + // This means the iterator was initially empty. 
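The iterator below wraps the per-row XML emitted by StaxXmlGenerator with the rootTag element and indents each row, while the codecClass resolved above selects the compressed saveAsTextFile overload at the end of the method. Here is a write-side usage sketch covering the tag and compression options the test suite exercises; the cars DataFrame and output paths are illustrative.

import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.sql.SaveMode

cars.write.format("xml")
  .mode(SaveMode.Overwrite)
  .options(Map("rootTag" -> "cars", "rowTag" -> "car"))
  .save("/tmp/cars-plain.xml")

// Compression can be requested either as a codec class name or as a short
// "compression" name; both forms are resolved to a codec class from the options map.
cars.write.format("xml")
  .mode(SaveMode.Overwrite)
  .option("codec", classOf[GzipCodec].getName)
  .save("/tmp/cars-gzip.xml")

cars.write.format("xml")
  .mode(SaveMode.Overwrite)
  .option("compression", "gzip")
  .save("/tmp/cars-gzip-short.xml")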
+ firstRow = false + lastRow = false + "" + } + } + } + } + } + + codecClass match { + case null => xmlRDD.saveAsTextFile(path) + case codec => xmlRDD.saveAsTextFile(path, codec) + } + } } diff --git a/src/test/java/com/databricks/spark/xml/JavaXmlSuite.java b/src/test/java/com/databricks/spark/xml/JavaXmlSuite.java index 68a6abb1..2a105cb8 100644 --- a/src/test/java/com/databricks/spark/xml/JavaXmlSuite.java +++ b/src/test/java/com/databricks/spark/xml/JavaXmlSuite.java @@ -49,9 +49,9 @@ public void tearDown() { @Test public void testXmlParser() { - DataFrame df = (new XmlReader()).withRowTag(booksFileTag).xmlFile(sqlContext, booksFile); + Dataset df = (new XmlReader()).withRowTag(booksFileTag).xmlFile(sqlContext, booksFile); String prefix = XmlOptions.DEFAULT_ATTRIBUTE_PREFIX(); - int result = df.select(prefix + "id").collect().length; + long result = df.select(prefix + "id").count(); Assert.assertEquals(result, numBooks); } @@ -61,19 +61,19 @@ public void testLoad() { options.put("rowTag", booksFileTag); options.put("path", booksFile); - DataFrame df = sqlContext.load("com.databricks.spark.xml", options); - int result = df.select("description").collect().length; + Dataset df = sqlContext.load("com.databricks.spark.xml", options); + long result = df.select("description").count(); Assert.assertEquals(result, numBooks); } @Test public void testSave() { - DataFrame df = (new XmlReader()).withRowTag(booksFileTag).xmlFile(sqlContext, booksFile); + Dataset df = (new XmlReader()).withRowTag(booksFileTag).xmlFile(sqlContext, booksFile); TestUtils.deleteRecursively(new File(tempDir)); - df.select("price", "description").save(tempDir, "com.databricks.spark.xml"); + df.select("price", "description").write().format("xml").save(tempDir); - DataFrame newDf = (new XmlReader()).xmlFile(sqlContext, tempDir); - int result = newDf.select("price").collect().length; + Dataset newDf = (new XmlReader()).xmlFile(sqlContext, tempDir); + long result = newDf.select("price").count(); Assert.assertEquals(result, numBooks); } } diff --git a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala index 623533ad..d1229c83 100755 --- a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala +++ b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala @@ -85,8 +85,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test") { - val results = sqlContext - .xmlFile(carsFile) + val results = sqlContext.read.format("xml") + .load(carsFile) .select("year") .collect() @@ -94,15 +94,16 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test with xml having unbalanced datatypes") { - val results = sqlContext - .xmlFile(gpsEmptyField, treatEmptyValuesAsNulls = true) + val results = sqlContext.read.format("xml") + .option("treatEmptyValuesAsNulls", "true") + .load(gpsEmptyField) assert(results.collect().size === numGPS) } test("DSL test with mixed elements (attributes, no child)") { - val results = sqlContext - .xmlFile(carsMixedAttrNoChildFile) + val results = sqlContext.read.format("xml") + .load(carsMixedAttrNoChildFile) .select("date") .collect() @@ -114,8 +115,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test for inconsistent element attributes as fields") { - val results = sqlContext - .xmlFile(booksAttributesInNoChild, rowTag = booksTag) + val results = sqlContext.read.format("xml") + .option("rowTag", booksTag) + .load(booksAttributesInNoChild) .select("price") // This should not throw an exception 
`java.lang.ArrayIndexOutOfBoundsException` @@ -130,13 +132,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test with mixed elements (struct, string)") { - val results = sqlContext - .xmlFile(agesMixedTypes, rowTag = agesTag).collect() + val results = sqlContext.read.format("xml") + .option("rowTag", agesTag) + .load(agesMixedTypes) + .collect() assert(results.size === numAges) } test("DSL test with elements in array having attributes") { - val results = sqlContext.xmlFile(agesFile, rowTag = agesTag).collect() + val results = sqlContext.read.format("xml") + .option("rowTag", agesTag) + .load(agesFile) + .collect() val attrValOne = results(0).get(0).asInstanceOf[Row](1) val attrValTwo = results(1).get(0).asInstanceOf[Row](1) assert(attrValOne == "1990-02-24") @@ -158,8 +165,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test compressed file") { - val results = sqlContext - .xmlFile(carsFileGzip) + val results = sqlContext.read.format("xml") + .load(carsFileGzip) .select("year") .collect() @@ -167,8 +174,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test splittable compressed file") { - val results = sqlContext - .xmlFile(carsFileBzip2) + val results = sqlContext.read.format("xml") + .load(carsFileBzip2) .select("year") .collect() @@ -177,8 +184,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { test("DSL test bad charset name") { val exception = intercept[UnsupportedCharsetException] { - val results = sqlContext - .xmlFile(carsFile, charset = "1-9588-osi") + val results = sqlContext.read.format("xml") + .option("charset", "1-9588-osi") + .load(carsFile) .select("year") .collect() } @@ -333,14 +341,17 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { new File(tempEmptyDir).mkdirs() val copyFilePath = tempEmptyDir + "cars-copy.xml" - val cars = sqlContext.xmlFile(carsFile) - cars.save("com.databricks.spark.xml", SaveMode.Overwrite, - Map("path" -> copyFilePath, "codec" -> classOf[GzipCodec].getName)) + val cars = sqlContext.read.format("xml").load(carsFile) + cars.write + .format("xml") + .mode(SaveMode.Overwrite) + .options(Map("path" -> copyFilePath, "codec" -> classOf[GzipCodec].getName)) + .save(copyFilePath) val carsCopyPartFile = new File(copyFilePath, "part-00000.gz") // Check that the part file has a .gz extension assert(carsCopyPartFile.exists()) - val carsCopy = sqlContext.xmlFile(copyFilePath + "/") + val carsCopy = sqlContext.read.format("xml").load(copyFilePath) assert(carsCopy.count == cars.count) assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) @@ -352,14 +363,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { new File(tempEmptyDir).mkdirs() val copyFilePath = tempEmptyDir + "cars-copy.xml" - val cars = sqlContext.xmlFile(carsFile) - cars.save("com.databricks.spark.xml", SaveMode.Overwrite, - Map("path" -> copyFilePath, "compression" -> "gZiP")) + val cars = sqlContext.read.format("xml").load(carsFile) + cars.write + .format("xml") + .mode(SaveMode.Overwrite) + .options(Map("path" -> copyFilePath, "compression" -> "gZiP")) + .save(copyFilePath) + val carsCopyPartFile = new File(copyFilePath, "part-00000.gz") // Check that the part file has a .gz extension assert(carsCopyPartFile.exists()) - val carsCopy = sqlContext.xmlFile(copyFilePath) + val carsCopy = sqlContext.read.format("xml").load(copyFilePath) assert(carsCopy.count == cars.count) assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) @@ 
-371,11 +386,17 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { new File(tempEmptyDir).mkdirs() val copyFilePath = tempEmptyDir + "books-copy.xml" - val books = sqlContext.xmlFile(booksComplicatedFile, rowTag = booksTag) - books.saveAsXmlFile(copyFilePath, - Map("rootTag" -> booksRootTag, "rowTag" -> booksTag)) - - val booksCopy = sqlContext.xmlFile(copyFilePath + "/", rowTag = booksTag) + val books = sqlContext.read.format("xml") + .option("rowTag", booksTag) + .load(booksComplicatedFile) + books.write + .options(Map("rootTag" -> booksRootTag, "rowTag" -> booksTag)) + .format("xml") + .save(copyFilePath) + + val booksCopy = sqlContext.read.format("xml") + .option("rowTag", booksTag) + .load(copyFilePath) assert(booksCopy.count == books.count) assert(booksCopy.collect.map(_.toString).toSet === books.collect.map(_.toString).toSet) } @@ -386,12 +407,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { new File(tempEmptyDir).mkdirs() val copyFilePath = tempEmptyDir + "books-copy.xml" - val books = sqlContext.xmlFile(booksComplicatedFile, rowTag = booksTag) - books.saveAsXmlFile(copyFilePath, - Map("rootTag" -> booksRootTag, "rowTag" -> booksTag, "nullValue" -> "")) + val books = sqlContext.read.format("xml") + .option("rowTag", booksTag) + .load(booksComplicatedFile) + books.write + .options(Map("rootTag" -> booksRootTag, "rowTag" -> booksTag, "nullValue" -> "")) + .format("xml") + .save(copyFilePath) - val booksCopy = - sqlContext.xmlFile(copyFilePath, rowTag = booksTag, treatEmptyValuesAsNulls = true) + val booksCopy = sqlContext.read.format("xml") + .option("rowTag", booksTag) + .option("treatEmptyValuesAsNulls", "true") + .load(copyFilePath) assert(booksCopy.count == books.count) assert(booksCopy.collect.map(_.toString).toSet === books.collect.map(_.toString).toSet) @@ -408,7 +435,7 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { val data = sqlContext.sparkContext.parallelize( List(List(List("aa", "bb"), List("aa", "bb")))).map(Row(_)) val df = sqlContext.createDataFrame(data, schema) - df.saveAsXmlFile(copyFilePath) + df.write.format("xml").save(copyFilePath) // When [[ArrayType]] has [[ArrayType]] as elements, it is confusing what is the element // name for XML file. Now, it is "item". 
So, "item" field is additionally added @@ -416,8 +443,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { val schemaCopy = StructType( List(StructField("a", ArrayType( StructType(List(StructField("item", ArrayType(StringType), nullable = true)))), - nullable = true))) - val dfCopy = sqlContext.xmlFile(copyFilePath + "/") + nullable = true))) + val dfCopy = sqlContext.read.format("xml").load(copyFilePath) assert(dfCopy.count == df.count) assert(dfCopy.schema === schemaCopy) @@ -453,19 +480,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { val data = sqlContext.sparkContext.parallelize(Seq(row)) val df = sqlContext.createDataFrame(data, schema) - df.saveAsXmlFile(copyFilePath) + df.write.format("xml").save(copyFilePath) val dfCopy = new XmlReader() .withSchema(schema) - .xmlFile(sqlContext, copyFilePath + "/") + .xmlFile(sqlContext, copyFilePath) assert(dfCopy.collect() === df.collect()) assert(dfCopy.schema === df.schema) } test("DSL test schema inferred correctly") { - val results = sqlContext - .xmlFile(booksFile, rowTag = booksTag) + val results = sqlContext.read.format("xml").option("rowTag", booksTag).load(booksFile) assert(results.schema == StructType(List( StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}id", StringType, nullable = true), @@ -481,8 +507,10 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test schema inferred correctly with sampling ratio") { - val results = sqlContext - .xmlFile(booksFile, rowTag = booksTag, samplingRatio = 0.5) + val results = sqlContext.read.format("xml") + .option("rowTag", booksTag) + .option("samplingRatio", 0.5) + .load(booksFile) assert(results.schema == StructType(List( StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}id", StringType, nullable = true), @@ -498,8 +526,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test schema (object) inferred correctly") { - val results = sqlContext - .xmlFile(booksNestedObjectFile, rowTag = booksTag) + val results = sqlContext.read.format("xml") + .option("rowTag", booksTag) + .load(booksNestedObjectFile) assert(results.schema == StructType(List( StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}id", StringType, nullable = true), @@ -516,8 +545,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test schema (array) inferred correctly") { - val results = sqlContext - .xmlFile(booksNestedArrayFile, rowTag = booksTag) + val results = sqlContext.read.format("xml") + .option("rowTag", booksTag) + .load(booksNestedArrayFile) assert(results.schema == StructType(List( StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}id", StringType, nullable = true), @@ -533,8 +563,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test schema (complicated) inferred correctly") { - val results = sqlContext - .xmlFile(booksComplicatedFile, rowTag = booksTag) + val results = sqlContext.read.format("xml") + .option("rowTag", booksTag) + .load(booksComplicatedFile) assert(results.schema == StructType(List( StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}id", StringType, nullable = true), @@ -639,8 +670,7 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test inferred schema passed through") { - val dataFrame = sqlContext - .xmlFile(carsFile) + val dataFrame = sqlContext.read.format("xml").load(carsFile) val results = dataFrame .select("comment", "year") @@ -676,16 +706,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { } test("DSL test with namespaces ignored") { - val results = sqlContext - .xmlFile(topicsFile, rowTag = 
topicsTag) + val results = sqlContext.read.format("xml") + .option("rowTag", topicsTag) + .load(topicsFile) .collect() assert(results.size === numTopics) } test("Missing nested struct represented as null instead of empty Row") { - val result = sqlContext - .xmlFile(nullNestedStructFile, rowTag = "item") + val result = sqlContext.read.format("xml") + .option("rowTag", "item") + .load(nullNestedStructFile) .select("b.es") .collect() @@ -751,18 +783,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { test("Empty string not allowed for rowTag, attributePrefix and valueTag.") { val messageOne = intercept[IllegalArgumentException] { - sqlContext.xmlFile(carsFile, rowTag = "").collect() + sqlContext.read.format("xml").option("rowTag", "").load(carsFile) }.getMessage assert(messageOne == "requirement failed: 'rowTag' option should not be empty string.") val messageTwo = intercept[IllegalArgumentException] { - sqlContext.xmlFile(carsFile, attributePrefix = "").collect() + sqlContext.read.format("xml").option("attributePrefix", "").load(carsFile) }.getMessage assert( messageTwo == "requirement failed: 'attributePrefix' option should not be empty string.") val messageThree = intercept[IllegalArgumentException] { - sqlContext.xmlFile(carsFile, valueTag = "").collect() + sqlContext.read.format("xml").option("valueTag", "").load(carsFile) }.getMessage assert(messageThree == "requirement failed: 'valueTag' option should not be empty string.") } diff --git a/src/test/scala/com/databricks/spark/xml/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/xml/util/TypeCastSuite.scala index f2bc420f..03f455c8 100644 --- a/src/test/scala/com/databricks/spark/xml/util/TypeCastSuite.scala +++ b/src/test/scala/com/databricks/spark/xml/util/TypeCastSuite.scala @@ -28,7 +28,7 @@ class TypeCastSuite extends FunSuite { test("Can parse decimal type values") { val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") val decimalValues = Seq(10.05, 1000.01, 158058049.001) - val decimalType = new DecimalType(None) + val decimalType = DecimalType.SYSTEM_DEFAULT stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => assert(TypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString))
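One more Spark 2.0 consequence worth spelling out: the unlimited-precision forms used on 1.x (DecimalType.Unlimited, new DecimalType(None)) are gone, so InferSchema and TypeCastSuite above switch to DecimalType.SYSTEM_DEFAULT, which is DecimalType(38, 18). Below is a small sketch mirroring the suite's assertion; like the suite, it would need to live under com.databricks.spark.xml, since the util classes are package-private.

import java.math.BigDecimal
import org.apache.spark.sql.types.DecimalType
import com.databricks.spark.xml.util.TypeCast

val decimalType = DecimalType.SYSTEM_DEFAULT   // precision 38, scale 18
// Thousands separators are handled during the cast, as the suite checks.
assert(TypeCast.castTo("1,000.01", decimalType) == new BigDecimal("1000.01"))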