diff --git a/.travis.yml b/.travis.yml
index ba5b5d44..ba009c73 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,37 +5,13 @@ cache:
- $HOME/.ivy2
matrix:
include:
- # Spark 1.3.0
- - jdk: openjdk6
- scala: 2.10.5
- env: TEST_HADOOP_VERSION="1.2.1" TEST_SPARK_VERSION="1.3.0"
- - jdk: openjdk6
- scala: 2.11.7
- env: TEST_HADOOP_VERSION="1.0.4" TEST_SPARK_VERSION="1.3.0"
- # Spark 1.4.1
- # We only test Spark 1.4.1 with Hadoop 2.2.0 because
- # https://github.com/apache/spark/pull/6599 is not present in 1.4.1,
- # so the published Spark Maven artifacts will not work with Hadoop 1.x.
- - jdk: openjdk6
- scala: 2.10.5
- env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="1.4.1"
- - jdk: openjdk7
- scala: 2.11.7
- env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="1.4.1"
- # Spark 1.5.0
- - jdk: openjdk7
- scala: 2.10.5
- env: TEST_HADOOP_VERSION="1.0.4" TEST_SPARK_VERSION="1.5.0"
- - jdk: openjdk7
- scala: 2.11.7
- env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="1.5.0"
- # Spark 1.6.0
+ # Spark 2.0.0
- jdk: openjdk7
scala: 2.10.5
- env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="1.6.0"
+ env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0"
- jdk: openjdk7
scala: 2.11.7
- env: TEST_HADOOP_VERSION="1.2.1" TEST_SPARK_VERSION="1.6.0"
+ env: TEST_HADOOP_VERSION="2.6.0" TEST_SPARK_VERSION="2.0.0"
script:
- sbt -Dhadoop.testVersion=$TEST_HADOOP_VERSION -Dspark.testVersion=$TEST_SPARK_VERSION ++$TRAVIS_SCALA_VERSION coverage test
- sbt ++$TRAVIS_SCALA_VERSION assembly
diff --git a/README.md b/README.md
index c53c654d..2ec1d6aa 100644
--- a/README.md
+++ b/README.md
@@ -496,4 +496,3 @@ This library is built with [SBT](http://www.scala-sbt.org/0.13/docs/Command-Line
## Acknowledgements
This project was initially created by [HyukjinKwon](https://github.com/HyukjinKwon) and donated to [Databricks](https://databricks.com).
-
diff --git a/build.sbt b/build.sbt
index ca008e3f..5a8c5e51 100755
--- a/build.sbt
+++ b/build.sbt
@@ -1,6 +1,6 @@
name := "spark-xml"
-version := "0.3.4"
+version := "0.4.0-SNAPSHOT"
organization := "com.databricks"
@@ -10,7 +10,7 @@ spName := "databricks/spark-xml"
crossScalaVersions := Seq("2.10.5", "2.11.7")
-sparkVersion := "1.6.0"
+sparkVersion := "2.0.0"
val testSparkVersion = settingKey[String]("The version of Spark to test against.")
@@ -75,26 +75,3 @@ ScoverageSbtPlugin.ScoverageKeys.coverageHighlighting := {
if (scalaBinaryVersion.value == "2.10") false
else true
}
-
-// -- MiMa binary compatibility checks ------------------------------------------------------------
-
-//import com.typesafe.tools.mima.core._
-//import com.typesafe.tools.mima.plugin.MimaKeys.binaryIssueFilters
-//import com.typesafe.tools.mima.plugin.MimaKeys.previousArtifact
-//import com.typesafe.tools.mima.plugin.MimaPlugin.mimaDefaultSettings
-//
-//mimaDefaultSettings ++ Seq(
-// previousArtifact := Some("org.apache" %% "spark-xml" % "1.2.0"),
-// binaryIssueFilters ++= Seq(
-// // These classes are not intended to be public interfaces:
-// ProblemFilters.excludePackage("org.apache.spark.xml.XmlRelation"),
-// ProblemFilters.excludePackage("org.apache.spark.xml.util.InferSchema"),
-// ProblemFilters.excludePackage("org.apache.spark.sql.readers"),
-// ProblemFilters.excludePackage("org.apache.spark.xml.util.TypeCast"),
-// // We allowed the private `XmlRelation` type to leak into the public method signature:
-// ProblemFilters.exclude[IncompatibleResultTypeProblem](
-// "org.apache.spark.xml.DefaultSource.createRelation")
-// )
-//)
-
-// ------------------------------------------------------------------------------------------------
diff --git a/src/main/scala/com/databricks/spark/xml/DefaultSource.scala b/src/main/scala/com/databricks/spark/xml/DefaultSource.scala
index 796ca2dc..49fe25c0 100755
--- a/src/main/scala/com/databricks/spark/xml/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/xml/DefaultSource.scala
@@ -20,7 +20,6 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
-import com.databricks.spark.xml.util.CompressionCodecs
import com.databricks.spark.xml.util.XmlFile
/**
@@ -90,9 +89,7 @@ class DefaultSource
}
if (doSave) {
// Only save data when the save mode is not ignore.
- val codecClass =
- CompressionCodecs.getCodecClass(XmlOptions(parameters).codec)
- data.saveAsXmlFile(filesystemPath.toString, parameters, codecClass)
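+ // Codec resolution now happens inside XmlFile.saveAsXmlFile, based on the supplied options.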
+ XmlFile.saveAsXmlFile(data, filesystemPath.toString, parameters)
}
createRelation(sqlContext, parameters, data.schema)
}
diff --git a/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala b/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala
index c124b11e..66bd7272 100644
--- a/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala
+++ b/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala
@@ -70,12 +70,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {
val fileSplit: FileSplit = split.asInstanceOf[FileSplit]
- val conf: Configuration = {
- // Use reflection to get the Configuration. This is necessary because TaskAttemptContext is
- // a class in Hadoop 1.x and an interface in Hadoop 2.x.
- val method = context.getClass.getMethod("getConfiguration")
- method.invoke(context).asInstanceOf[Configuration]
- }
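+ // TaskAttemptContext is an interface in Hadoop 2.x, so the configuration can be
+ // obtained directly instead of through reflection.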
+ val conf: Configuration = context.getConfiguration
val charset =
Charset.forName(conf.get(XmlInputFormat.ENCODING_KEY, XmlOptions.DEFAULT_CHARSET))
startTag = conf.get(XmlInputFormat.START_TAG_KEY).getBytes(charset)
@@ -97,25 +92,18 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
val codec = new CompressionCodecFactory(conf).getCodec(path)
if (null != codec) {
decompressor = CodecPool.getDecompressor(codec)
- // Use reflection to get the splittable compression codec and stream. This is necessary
- // because SplittableCompressionCodec does not exist in Hadoop 1.0.x.
- def isSplitCompressionCodec(obj: Any) = {
- val splittableClassName = "org.apache.hadoop.io.compress.SplittableCompressionCodec"
- obj.getClass.getInterfaces.map(_.getName).contains(splittableClassName)
- }
- // Here I made separate variables to avoid to try to find SplitCompressionInputStream at
- // runtime.
- val (inputStream, seekable) = codec match {
- case c: CompressionCodec if isSplitCompressionCodec(c) =>
- // At Hadoop 1.0.x, this case would not be executed.
- val cIn = {
- val sc = c.asInstanceOf[SplittableCompressionCodec]
- sc.createInputStream(fsin, decompressor, start,
- end, SplittableCompressionCodec.READ_MODE.BYBLOCK)
- }
+ codec match {
+ case sc: SplittableCompressionCodec =>
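+ // A splittable codec (e.g. bzip2) can be read starting from an arbitrary split,
+ // so the split boundaries are adjusted to the codec's block boundaries below.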
+ val cIn = sc.createInputStream(
+ fsin,
+ decompressor,
+ start,
+ end,
+ SplittableCompressionCodec.READ_MODE.BYBLOCK)
start = cIn.getAdjustedStart
end = cIn.getAdjustedEnd
- (cIn, cIn)
+ in = cIn
+ filePosition = cIn
case c: CompressionCodec =>
if (start != 0) {
// So we have a split that is only part of a file stored using
@@ -124,10 +112,9 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
codec.getClass.getSimpleName + " compressed stream")
}
val cIn = c.createInputStream(fsin, decompressor)
- (cIn, fsin)
+ in = cIn
+ filePosition = fsin
}
- in = inputStream
- filePosition = seekable
} else {
in = fsin
filePosition = fsin
diff --git a/src/main/scala/com/databricks/spark/xml/XmlRelation.scala b/src/main/scala/com/databricks/spark/xml/XmlRelation.scala
index 129f5b38..f72e5150 100755
--- a/src/main/scala/com/databricks/spark/xml/XmlRelation.scala
+++ b/src/main/scala/com/databricks/spark/xml/XmlRelation.scala
@@ -24,7 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.sources.{PrunedScan, InsertableRelation, BaseRelation, TableScan}
import org.apache.spark.sql.types._
-import com.databricks.spark.xml.util.{CompressionCodecs, InferSchema}
+import com.databricks.spark.xml.util.{InferSchema, XmlFile}
import com.databricks.spark.xml.parsers.StaxXmlParser
case class XmlRelation protected[spark] (
@@ -90,8 +90,7 @@ case class XmlRelation protected[spark] (
+ s" to INSERT OVERWRITE a XML table:\n${e.toString}")
}
// Write the data. We assume that schema isn't changed, and we won't update it.
- val codecClass = CompressionCodecs.getCodecClass(options.codec)
- data.saveAsXmlFile(filesystemPath.toString, parameters, codecClass)
+ XmlFile.saveAsXmlFile(data, filesystemPath.toString, parameters)
} else {
sys.error("XML tables only support INSERT OVERWRITE for now.")
}
diff --git a/src/main/scala/com/databricks/spark/xml/package.scala b/src/main/scala/com/databricks/spark/xml/package.scala
index 5eae2afd..b0b10994 100755
--- a/src/main/scala/com/databricks/spark/xml/package.scala
+++ b/src/main/scala/com/databricks/spark/xml/package.scala
@@ -15,23 +15,19 @@
*/
package com.databricks.spark
-import java.io.CharArrayWriter
-import javax.xml.stream.XMLOutputFactory
-
import scala.collection.Map
-import com.sun.xml.internal.txw2.output.IndentingXMLStreamWriter
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.spark.sql.{DataFrame, SQLContext}
import com.databricks.spark.xml.util.XmlFile
-import com.databricks.spark.xml.parsers.StaxXmlGenerator
package object xml {
/**
* Adds a method, `xmlFile`, to [[SQLContext]] that allows reading XML data.
*/
implicit class XmlContext(sqlContext: SQLContext) extends Serializable {
+ @deprecated("Use DataFrameReader.read()", "0.4.0")
def xmlFile(
filePath: String,
rowTag: String = XmlOptions.DEFAULT_ROW_TAG,
@@ -82,69 +78,16 @@ package object xml {
//
//
// Namely, roundtrip in writing and reading can end up in different schema structure.
+ @deprecated("Use DataFrameWriter.write()", "0.4.0")
def saveAsXmlFile(
path: String, parameters: Map[String, String] = Map(),
compressionCodec: Class[_ <: CompressionCodec] = null): Unit = {
- val options = XmlOptions(parameters.toMap)
- val startElement = s"<${options.rootTag}>"
- val endElement = s"</${options.rootTag}>"
- val rowSchema = dataFrame.schema
- val indent = XmlFile.DEFAULT_INDENT
- val rowSeparator = XmlFile.DEFAULT_ROW_SEPARATOR
-
- val xmlRDD = dataFrame.rdd.mapPartitions { iter =>
- val factory = XMLOutputFactory.newInstance()
- val writer = new CharArrayWriter()
- val xmlWriter = factory.createXMLStreamWriter(writer)
- val indentingXmlWriter = new IndentingXMLStreamWriter(xmlWriter)
- indentingXmlWriter.setIndentStep(indent)
-
- new Iterator[String] {
- var firstRow: Boolean = true
- var lastRow: Boolean = true
-
- override def hasNext: Boolean = iter.hasNext || firstRow || lastRow
-
- override def next: String = {
- if (iter.nonEmpty) {
- val xml = {
- StaxXmlGenerator(
- rowSchema,
- indentingXmlWriter,
- options)(iter.next())
- writer.toString
- }
- writer.reset()
-
- // Here it needs to add indentations for the start of each line,
- // in order to insert the start element and end element.
- val indentedXml = indent + xml.replaceAll(rowSeparator, rowSeparator + indent)
- if (firstRow) {
- firstRow = false
- startElement + rowSeparator + indentedXml
- } else {
- indentedXml
- }
- } else {
- indentingXmlWriter.close()
- if (!firstRow) {
- lastRow = false
- endElement
- } else {
- // This means the iterator was initially empty.
- firstRow = false
- lastRow = false
- ""
- }
- }
- }
- }
- }
-
- compressionCodec match {
- case null => xmlRDD.saveAsTextFile(path)
- case codec => xmlRDD.saveAsTextFile(path, codec)
- }
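+ // An explicit "codec" option takes precedence; otherwise fall back to the
+ // compressionCodec argument of this deprecated API before delegating to XmlFile.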
+ val mutableParams = collection.mutable.Map(parameters.toSeq: _*)
+ val safeCodec = mutableParams.get("codec")
+ .orElse(Option(compressionCodec).map(_.getCanonicalName))
+ .orNull
+ mutableParams.put("codec", safeCodec)
+ XmlFile.saveAsXmlFile(dataFrame, path, mutableParams.toMap)
}
}
}
diff --git a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala
index 8d14ff94..cfe08024 100644
--- a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala
+++ b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlGenerator.scala
@@ -85,7 +85,6 @@ private[xml] object StaxXmlGenerator {
case (ByteType, v: Byte) => writer.writeCharacters(v.toString)
case (BooleanType, v: Boolean) => writer.writeCharacters(v.toString)
case (DateType, v) => writer.writeCharacters(v.toString)
- case (udt: UserDefinedType[_], v) => writeElement(udt.sqlType, udt.serialize(v))
// For the case roundtrip in reading and writing XML files, [[ArrayType]] cannot have
// [[ArrayType]] as element type. It always wraps the element with [[StructType]]. So,
diff --git a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
index 404ac0fd..88544977 100644
--- a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
+++ b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
@@ -106,7 +106,6 @@ private[xml] object StaxXmlParser {
case dt: StructType => convertObject(parser, dt, options)
case MapType(StringType, vt, _) => convertMap(parser, vt, options)
case ArrayType(st, _) => convertField(parser, st, options)
- case udt: UserDefinedType[_] => convertField(parser, udt.sqlType, options)
case _: StringType => StaxXmlParserUtils.currentStructureAsString(parser)
}
@@ -153,7 +152,7 @@ private[xml] object StaxXmlParser {
case (v, ByteType) => castTo(v, ByteType)
case (v, ShortType) => castTo(v, ShortType)
case (v, IntegerType) => signSafeToInt(v)
- case (v, _: DecimalType) => castTo(v, new DecimalType(None))
+ case (v, dt: DecimalType) => castTo(v, dt)
case (_, dataType) =>
sys.error(s"Failed to parse a value for data type $dataType.")
}
diff --git a/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala b/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala
index ad4471b6..a3649a37 100644
--- a/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala
+++ b/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala
@@ -36,7 +36,7 @@ private[xml] object InferSchema {
/**
* Copied from internal Spark api
- * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
+ * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion]]
*/
private val numericPrecedence: IndexedSeq[DataType] =
IndexedSeq[DataType](
@@ -47,7 +47,7 @@ private[xml] object InferSchema {
FloatType,
DoubleType,
TimestampType,
- DecimalType.Unlimited)
+ DecimalType.SYSTEM_DEFAULT)
val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = {
case (t1, t2) if t1 == t2 => Some(t1)
diff --git a/src/main/scala/com/databricks/spark/xml/util/XmlFile.scala b/src/main/scala/com/databricks/spark/xml/util/XmlFile.scala
index 5abb08b2..155f7edd 100644
--- a/src/main/scala/com/databricks/spark/xml/util/XmlFile.scala
+++ b/src/main/scala/com/databricks/spark/xml/util/XmlFile.scala
@@ -15,12 +15,19 @@
*/
package com.databricks.spark.xml.util
+import java.io.CharArrayWriter
import java.nio.charset.Charset
+import javax.xml.stream.XMLOutputFactory
+import scala.collection.Map
+
+import com.databricks.spark.xml.parsers.StaxXmlGenerator
+import com.sun.xml.internal.txw2.output.IndentingXMLStreamWriter
import org.apache.hadoop.io.{Text, LongWritable}
-import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.DataFrame
import com.databricks.spark.xml.{XmlOptions, XmlInputFormat}
private[xml] object XmlFile {
@@ -47,4 +54,90 @@ private[xml] object XmlFile {
classOf[Text]).map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))
}
}
+
+ /**
+ * Note that writing an XML file from a [[DataFrame]] that has a field of
+ * [[org.apache.spark.sql.types.ArrayType]] whose elements are nested arrays adds
+ * an additional nested field for those elements. For example, a [[DataFrame]] with
+ * the field below,
+ *
+ * fieldA Array(Array(data1, data2))
+ *
+ * would produce an XML file like the one below.
+ *
+ *   <fieldA>
+ *       <item>data1</item>
+ *   </fieldA>
+ *   <fieldA>
+ *       <item>data2</item>
+ *   </fieldA>
+ *
+ * In other words, a write-then-read roundtrip can end up with a different schema structure.
+ */
+ def saveAsXmlFile(
+ dataFrame: DataFrame,
+ path: String,
+ parameters: Map[String, String] = Map()): Unit = {
+ val options = XmlOptions(parameters.toMap)
+ val codecClass = CompressionCodecs.getCodecClass(options.codec)
+ val startElement = s"<${options.rootTag}>"
+ val endElement = s"</${options.rootTag}>"
+ val rowSchema = dataFrame.schema
+ val indent = XmlFile.DEFAULT_INDENT
+ val rowSeparator = XmlFile.DEFAULT_ROW_SEPARATOR
+
+ val xmlRDD = dataFrame.rdd.mapPartitions { iter =>
+ val factory = XMLOutputFactory.newInstance()
+ val writer = new CharArrayWriter()
+ val xmlWriter = factory.createXMLStreamWriter(writer)
+ val indentingXmlWriter = new IndentingXMLStreamWriter(xmlWriter)
+ indentingXmlWriter.setIndentStep(indent)
+
+ new Iterator[String] {
+ var firstRow: Boolean = true
+ var lastRow: Boolean = true
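+ // firstRow/lastRow keep this iterator alive for two extra elements: the first
+ // emitted string is prefixed with the root start element, and a final string
+ // carries the root end element (or "" if the partition was empty).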
+
+ override def hasNext: Boolean = iter.hasNext || firstRow || lastRow
+
+ override def next: String = {
+ if (iter.nonEmpty) {
+ val xml = {
+ StaxXmlGenerator(
+ rowSchema,
+ indentingXmlWriter,
+ options)(iter.next())
+ writer.toString
+ }
+ writer.reset()
+
+ // Indent every line of the row's XML by one level so that it nests
+ // properly under the root start and end elements.
+ val indentedXml = indent + xml.replaceAll(rowSeparator, rowSeparator + indent)
+ if (firstRow) {
+ firstRow = false
+ startElement + rowSeparator + indentedXml
+ } else {
+ indentedXml
+ }
+ } else {
+ indentingXmlWriter.close()
+ if (!firstRow) {
+ lastRow = false
+ endElement
+ } else {
+ // This means the iterator was initially empty.
+ firstRow = false
+ lastRow = false
+ ""
+ }
+ }
+ }
+ }
+ }
+
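+ // Write plain text part files when no codec is configured; otherwise compress
+ // each part file with the resolved codec class.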
+ codecClass match {
+ case null => xmlRDD.saveAsTextFile(path)
+ case codec => xmlRDD.saveAsTextFile(path, codec)
+ }
+ }
}
diff --git a/src/test/java/com/databricks/spark/xml/JavaXmlSuite.java b/src/test/java/com/databricks/spark/xml/JavaXmlSuite.java
index 68a6abb1..2a105cb8 100644
--- a/src/test/java/com/databricks/spark/xml/JavaXmlSuite.java
+++ b/src/test/java/com/databricks/spark/xml/JavaXmlSuite.java
@@ -49,9 +49,9 @@ public void tearDown() {
@Test
public void testXmlParser() {
- DataFrame df = (new XmlReader()).withRowTag(booksFileTag).xmlFile(sqlContext, booksFile);
+ Dataset df = (new XmlReader()).withRowTag(booksFileTag).xmlFile(sqlContext, booksFile);
String prefix = XmlOptions.DEFAULT_ATTRIBUTE_PREFIX();
- int result = df.select(prefix + "id").collect().length;
+ long result = df.select(prefix + "id").count();
Assert.assertEquals(result, numBooks);
}
@@ -61,19 +61,19 @@ public void testLoad() {
options.put("rowTag", booksFileTag);
options.put("path", booksFile);
- DataFrame df = sqlContext.load("com.databricks.spark.xml", options);
- int result = df.select("description").collect().length;
+ Dataset df = sqlContext.load("com.databricks.spark.xml", options);
+ long result = df.select("description").count();
Assert.assertEquals(result, numBooks);
}
@Test
public void testSave() {
- DataFrame df = (new XmlReader()).withRowTag(booksFileTag).xmlFile(sqlContext, booksFile);
+ Dataset df = (new XmlReader()).withRowTag(booksFileTag).xmlFile(sqlContext, booksFile);
TestUtils.deleteRecursively(new File(tempDir));
- df.select("price", "description").save(tempDir, "com.databricks.spark.xml");
+ df.select("price", "description").write().format("xml").save(tempDir);
- DataFrame newDf = (new XmlReader()).xmlFile(sqlContext, tempDir);
- int result = newDf.select("price").collect().length;
+ Dataset newDf = (new XmlReader()).xmlFile(sqlContext, tempDir);
+ long result = newDf.select("price").count();
Assert.assertEquals(result, numBooks);
}
}
diff --git a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala
index 623533ad..d1229c83 100755
--- a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala
+++ b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala
@@ -85,8 +85,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test") {
- val results = sqlContext
- .xmlFile(carsFile)
+ val results = sqlContext.read.format("xml")
+ .load(carsFile)
.select("year")
.collect()
@@ -94,15 +94,16 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test with xml having unbalanced datatypes") {
- val results = sqlContext
- .xmlFile(gpsEmptyField, treatEmptyValuesAsNulls = true)
+ val results = sqlContext.read.format("xml")
+ .option("treatEmptyValuesAsNulls", "true")
+ .load(gpsEmptyField)
assert(results.collect().size === numGPS)
}
test("DSL test with mixed elements (attributes, no child)") {
- val results = sqlContext
- .xmlFile(carsMixedAttrNoChildFile)
+ val results = sqlContext.read.format("xml")
+ .load(carsMixedAttrNoChildFile)
.select("date")
.collect()
@@ -114,8 +115,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test for inconsistent element attributes as fields") {
- val results = sqlContext
- .xmlFile(booksAttributesInNoChild, rowTag = booksTag)
+ val results = sqlContext.read.format("xml")
+ .option("rowTag", booksTag)
+ .load(booksAttributesInNoChild)
.select("price")
// This should not throw an exception `java.lang.ArrayIndexOutOfBoundsException`
@@ -130,13 +132,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test with mixed elements (struct, string)") {
- val results = sqlContext
- .xmlFile(agesMixedTypes, rowTag = agesTag).collect()
+ val results = sqlContext.read.format("xml")
+ .option("rowTag", agesTag)
+ .load(agesMixedTypes)
+ .collect()
assert(results.size === numAges)
}
test("DSL test with elements in array having attributes") {
- val results = sqlContext.xmlFile(agesFile, rowTag = agesTag).collect()
+ val results = sqlContext.read.format("xml")
+ .option("rowTag", agesTag)
+ .load(agesFile)
+ .collect()
val attrValOne = results(0).get(0).asInstanceOf[Row](1)
val attrValTwo = results(1).get(0).asInstanceOf[Row](1)
assert(attrValOne == "1990-02-24")
@@ -158,8 +165,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test compressed file") {
- val results = sqlContext
- .xmlFile(carsFileGzip)
+ val results = sqlContext.read.format("xml")
+ .load(carsFileGzip)
.select("year")
.collect()
@@ -167,8 +174,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test splittable compressed file") {
- val results = sqlContext
- .xmlFile(carsFileBzip2)
+ val results = sqlContext.read.format("xml")
+ .load(carsFileBzip2)
.select("year")
.collect()
@@ -177,8 +184,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
test("DSL test bad charset name") {
val exception = intercept[UnsupportedCharsetException] {
- val results = sqlContext
- .xmlFile(carsFile, charset = "1-9588-osi")
+ val results = sqlContext.read.format("xml")
+ .option("charset", "1-9588-osi")
+ .load(carsFile)
.select("year")
.collect()
}
@@ -333,14 +341,17 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
new File(tempEmptyDir).mkdirs()
val copyFilePath = tempEmptyDir + "cars-copy.xml"
- val cars = sqlContext.xmlFile(carsFile)
- cars.save("com.databricks.spark.xml", SaveMode.Overwrite,
- Map("path" -> copyFilePath, "codec" -> classOf[GzipCodec].getName))
+ val cars = sqlContext.read.format("xml").load(carsFile)
+ cars.write
+ .format("xml")
+ .mode(SaveMode.Overwrite)
+ .options(Map("path" -> copyFilePath, "codec" -> classOf[GzipCodec].getName))
+ .save(copyFilePath)
val carsCopyPartFile = new File(copyFilePath, "part-00000.gz")
// Check that the part file has a .gz extension
assert(carsCopyPartFile.exists())
- val carsCopy = sqlContext.xmlFile(copyFilePath + "/")
+ val carsCopy = sqlContext.read.format("xml").load(copyFilePath)
assert(carsCopy.count == cars.count)
assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet)
@@ -352,14 +363,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
new File(tempEmptyDir).mkdirs()
val copyFilePath = tempEmptyDir + "cars-copy.xml"
- val cars = sqlContext.xmlFile(carsFile)
- cars.save("com.databricks.spark.xml", SaveMode.Overwrite,
- Map("path" -> copyFilePath, "compression" -> "gZiP"))
+ val cars = sqlContext.read.format("xml").load(carsFile)
+ cars.write
+ .format("xml")
+ .mode(SaveMode.Overwrite)
+ .options(Map("path" -> copyFilePath, "compression" -> "gZiP"))
+ .save(copyFilePath)
+
val carsCopyPartFile = new File(copyFilePath, "part-00000.gz")
// Check that the part file has a .gz extension
assert(carsCopyPartFile.exists())
- val carsCopy = sqlContext.xmlFile(copyFilePath)
+ val carsCopy = sqlContext.read.format("xml").load(copyFilePath)
assert(carsCopy.count == cars.count)
assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet)
@@ -371,11 +386,17 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
new File(tempEmptyDir).mkdirs()
val copyFilePath = tempEmptyDir + "books-copy.xml"
- val books = sqlContext.xmlFile(booksComplicatedFile, rowTag = booksTag)
- books.saveAsXmlFile(copyFilePath,
- Map("rootTag" -> booksRootTag, "rowTag" -> booksTag))
-
- val booksCopy = sqlContext.xmlFile(copyFilePath + "/", rowTag = booksTag)
+ val books = sqlContext.read.format("xml")
+ .option("rowTag", booksTag)
+ .load(booksComplicatedFile)
+ books.write
+ .options(Map("rootTag" -> booksRootTag, "rowTag" -> booksTag))
+ .format("xml")
+ .save(copyFilePath)
+
+ val booksCopy = sqlContext.read.format("xml")
+ .option("rowTag", booksTag)
+ .load(copyFilePath)
assert(booksCopy.count == books.count)
assert(booksCopy.collect.map(_.toString).toSet === books.collect.map(_.toString).toSet)
}
@@ -386,12 +407,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
new File(tempEmptyDir).mkdirs()
val copyFilePath = tempEmptyDir + "books-copy.xml"
- val books = sqlContext.xmlFile(booksComplicatedFile, rowTag = booksTag)
- books.saveAsXmlFile(copyFilePath,
- Map("rootTag" -> booksRootTag, "rowTag" -> booksTag, "nullValue" -> ""))
+ val books = sqlContext.read.format("xml")
+ .option("rowTag", booksTag)
+ .load(booksComplicatedFile)
+ books.write
+ .options(Map("rootTag" -> booksRootTag, "rowTag" -> booksTag, "nullValue" -> ""))
+ .format("xml")
+ .save(copyFilePath)
- val booksCopy =
- sqlContext.xmlFile(copyFilePath, rowTag = booksTag, treatEmptyValuesAsNulls = true)
+ val booksCopy = sqlContext.read.format("xml")
+ .option("rowTag", booksTag)
+ .option("treatEmptyValuesAsNulls", "true")
+ .load(copyFilePath)
assert(booksCopy.count == books.count)
assert(booksCopy.collect.map(_.toString).toSet === books.collect.map(_.toString).toSet)
@@ -408,7 +435,7 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
val data = sqlContext.sparkContext.parallelize(
List(List(List("aa", "bb"), List("aa", "bb")))).map(Row(_))
val df = sqlContext.createDataFrame(data, schema)
- df.saveAsXmlFile(copyFilePath)
+ df.write.format("xml").save(copyFilePath)
// When [[ArrayType]] has [[ArrayType]] as elements, it is confusing what is the element
// name for XML file. Now, it is "item". So, "item" field is additionally added
@@ -416,8 +443,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
val schemaCopy = StructType(
List(StructField("a", ArrayType(
StructType(List(StructField("item", ArrayType(StringType), nullable = true)))),
- nullable = true)))
- val dfCopy = sqlContext.xmlFile(copyFilePath + "/")
+ nullable = true)))
+ val dfCopy = sqlContext.read.format("xml").load(copyFilePath)
assert(dfCopy.count == df.count)
assert(dfCopy.schema === schemaCopy)
@@ -453,19 +480,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
val data = sqlContext.sparkContext.parallelize(Seq(row))
val df = sqlContext.createDataFrame(data, schema)
- df.saveAsXmlFile(copyFilePath)
+ df.write.format("xml").save(copyFilePath)
val dfCopy = new XmlReader()
.withSchema(schema)
- .xmlFile(sqlContext, copyFilePath + "/")
+ .xmlFile(sqlContext, copyFilePath)
assert(dfCopy.collect() === df.collect())
assert(dfCopy.schema === df.schema)
}
test("DSL test schema inferred correctly") {
- val results = sqlContext
- .xmlFile(booksFile, rowTag = booksTag)
+ val results = sqlContext.read.format("xml").option("rowTag", booksTag).load(booksFile)
assert(results.schema == StructType(List(
StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}id", StringType, nullable = true),
@@ -481,8 +507,10 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test schema inferred correctly with sampling ratio") {
- val results = sqlContext
- .xmlFile(booksFile, rowTag = booksTag, samplingRatio = 0.5)
+ val results = sqlContext.read.format("xml")
+ .option("rowTag", booksTag)
+ .option("samplingRatio", 0.5)
+ .load(booksFile)
assert(results.schema == StructType(List(
StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}id", StringType, nullable = true),
@@ -498,8 +526,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test schema (object) inferred correctly") {
- val results = sqlContext
- .xmlFile(booksNestedObjectFile, rowTag = booksTag)
+ val results = sqlContext.read.format("xml")
+ .option("rowTag", booksTag)
+ .load(booksNestedObjectFile)
assert(results.schema == StructType(List(
StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}id", StringType, nullable = true),
@@ -516,8 +545,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test schema (array) inferred correctly") {
- val results = sqlContext
- .xmlFile(booksNestedArrayFile, rowTag = booksTag)
+ val results = sqlContext.read.format("xml")
+ .option("rowTag", booksTag)
+ .load(booksNestedArrayFile)
assert(results.schema == StructType(List(
StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}id", StringType, nullable = true),
@@ -533,8 +563,9 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test schema (complicated) inferred correctly") {
- val results = sqlContext
- .xmlFile(booksComplicatedFile, rowTag = booksTag)
+ val results = sqlContext.read.format("xml")
+ .option("rowTag", booksTag)
+ .load(booksComplicatedFile)
assert(results.schema == StructType(List(
StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}id", StringType, nullable = true),
@@ -639,8 +670,7 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test inferred schema passed through") {
- val dataFrame = sqlContext
- .xmlFile(carsFile)
+ val dataFrame = sqlContext.read.format("xml").load(carsFile)
val results = dataFrame
.select("comment", "year")
@@ -676,16 +706,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}
test("DSL test with namespaces ignored") {
- val results = sqlContext
- .xmlFile(topicsFile, rowTag = topicsTag)
+ val results = sqlContext.read.format("xml")
+ .option("rowTag", topicsTag)
+ .load(topicsFile)
.collect()
assert(results.size === numTopics)
}
test("Missing nested struct represented as null instead of empty Row") {
- val result = sqlContext
- .xmlFile(nullNestedStructFile, rowTag = "item")
+ val result = sqlContext.read.format("xml")
+ .option("rowTag", "item")
+ .load(nullNestedStructFile)
.select("b.es")
.collect()
@@ -751,18 +783,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
test("Empty string not allowed for rowTag, attributePrefix and valueTag.") {
val messageOne = intercept[IllegalArgumentException] {
- sqlContext.xmlFile(carsFile, rowTag = "").collect()
+ sqlContext.read.format("xml").option("rowTag", "").load(carsFile)
}.getMessage
assert(messageOne == "requirement failed: 'rowTag' option should not be empty string.")
val messageTwo = intercept[IllegalArgumentException] {
- sqlContext.xmlFile(carsFile, attributePrefix = "").collect()
+ sqlContext.read.format("xml").option("attributePrefix", "").load(carsFile)
}.getMessage
assert(
messageTwo == "requirement failed: 'attributePrefix' option should not be empty string.")
val messageThree = intercept[IllegalArgumentException] {
- sqlContext.xmlFile(carsFile, valueTag = "").collect()
+ sqlContext.read.format("xml").option("valueTag", "").load(carsFile)
}.getMessage
assert(messageThree == "requirement failed: 'valueTag' option should not be empty string.")
}
diff --git a/src/test/scala/com/databricks/spark/xml/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/xml/util/TypeCastSuite.scala
index f2bc420f..03f455c8 100644
--- a/src/test/scala/com/databricks/spark/xml/util/TypeCastSuite.scala
+++ b/src/test/scala/com/databricks/spark/xml/util/TypeCastSuite.scala
@@ -28,7 +28,7 @@ class TypeCastSuite extends FunSuite {
test("Can parse decimal type values") {
val stringValues = Seq("10.05", "1,000.01", "158,058,049.001")
val decimalValues = Seq(10.05, 1000.01, 158058049.001)
- val decimalType = new DecimalType(None)
+ val decimalType = DecimalType.SYSTEM_DEFAULT
stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
assert(TypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString))