Skip to content

Commit

Permalink
#417 SparkXML-related unit test added first (regression guard), SparkXMLHack removed, the test holds.
Browse files Browse the repository at this point in the history
  • Loading branch information
dk1844 committed May 12, 2021
1 parent d3b57a0 commit 603f3ac
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.slf4j.{Logger, LoggerFactory}
import za.co.absa.enceladus.common.{Constants, RecordIdGeneration}
import za.co.absa.enceladus.common.RecordIdGeneration._
import za.co.absa.enceladus.standardization.interpreter.dataTypes._
import za.co.absa.enceladus.standardization.interpreter.stages.{SchemaChecker, SparkXMLHack, TypeParser}
import za.co.absa.enceladus.standardization.interpreter.stages.{SchemaChecker, TypeParser}
import za.co.absa.enceladus.utils.error.ErrorMessage
import za.co.absa.enceladus.utils.schema.{SchemaUtils, SparkUtils}
import za.co.absa.enceladus.utils.transformations.ArrayTransformations
Expand Down Expand Up @@ -54,17 +54,8 @@ object StandardizationInterpreter {
logger.info(s"Step 1: Schema validation")
validateSchemaAgainstSelfInconsistencies(expSchema)

// TODO: remove when spark-xml handles empty arrays #417
val dfXmlSafe: Dataset[Row] = if (inputType.toLowerCase() == "xml") {
df.select(expSchema.fields.map { field: StructField =>
SparkXMLHack.hack(field, "", df).as(field.name)
}: _*)
} else {
df
}

logger.info(s"Step 2: Standardization")
val std = standardizeDataset(dfXmlSafe, expSchema, failOnInputNotPerSchema)
val std = standardizeDataset(df, expSchema, failOnInputNotPerSchema)

logger.info(s"Step 3: Clean the final error column")
val cleanedStd = cleanTheFinalErrorColumn(std)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<instrument><reportDate>2018-08-10</reportDate><rowId>1</rowId><legs><leg><price>1000</price></leg></legs></instrument>
<instrument><reportDate>2018-08-10</reportDate><rowId>2</rowId><legs><leg><price>2000</price></leg></legs></instrument>
<instrument><reportDate>2018-08-10</reportDate><rowId>3</rowId><legs></legs></instrument>
<instrument><reportDate>2018-08-10</reportDate><rowId>4</rowId></instrument>
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2018 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.enceladus.standardization

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types
import org.apache.spark.sql.types._
import org.mockito.scalatest.MockitoSugar
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.enceladus.dao.MenasDAO
import za.co.absa.enceladus.model.Dataset
import za.co.absa.enceladus.standardization.config.StandardizationConfig
import za.co.absa.enceladus.standardization.interpreter.StandardizationInterpreter
import za.co.absa.enceladus.standardization.interpreter.stages.PlainSchemaGenerator
import za.co.absa.enceladus.utils.implicits.DataFrameImplicits.DataFrameEnhancements
import za.co.absa.enceladus.utils.testUtils.SparkTestBase
import za.co.absa.enceladus.utils.udf.UDFLibrary

/**
 * Regression test for #417: standardization of XML input must handle empty and
 * missing array elements (`<legs></legs>` and absent `legs`) without the former
 * SparkXMLHack workaround.
 */
class StandardizationXmlSuite extends AnyFunSuite with SparkTestBase with MockitoSugar {
  private implicit val udfLibrary: UDFLibrary = new UDFLibrary()

  private val standardizationReader = new StandardizationPropertiesProvider()

  test("Reading data from XML input") {
    // The DAO is never actually queried here; a mock satisfies the implicit requirement.
    implicit val dao: MenasDAO = mock[MenasDAO]

    val args = ("--dataset-name Foo --dataset-version 1 --report-date 2018-08-10 --report-version 1 " +
      "--menas-auth-keytab src/test/resources/user.keytab.example " +
      "--raw-format xml --row-tag instrument").split(" ")

    val dataSet = Dataset("SpecialChars", 1, None, "", "", "SpecialChars", 1, conformance = Nil)
    val cmd = StandardizationConfig.getFromArguments(args)

    // Renamed from `csvReader`: `--raw-format xml` above makes this an XML reader.
    val xmlReader = standardizationReader.getFormatSpecificReader(cmd, dataSet)

    val baseSchema = StructType(Array(
      StructField("rowId", LongType),
      StructField("reportDate", StringType),
      StructField("legs", types.ArrayType(StructType(Array(
        StructField("leg", StructType(Array(
          StructField("price", IntegerType)
        )))
      ))))
    ))
    // All fields read as string + a corrupt-record column, so malformed rows surface
    // instead of being silently dropped.
    val inputSchema = PlainSchemaGenerator.generateInputSchema(baseSchema, Option("_corrupt_record"))
    val reader = xmlReader.schema(inputSchema)

    val sourceDF = reader.load("src/test/resources/data/standardization_xml_suite_data.txt")
    // not expecting corrupted records, but checking to be sure
    val corruptedRecords = sourceDF.filter(col("_corrupt_record").isNotNull)
    assert(corruptedRecords.isEmpty, s"Unexpected corrupted records found: ${corruptedRecords.collectAsList()}")

    val destDF = StandardizationInterpreter.standardize(sourceDF, baseSchema, cmd.rawFormat)

    // Rows 3 and 4 exercise the #417 regression: empty <legs></legs> and a missing
    // legs element respectively.
    val actual = destDF.dataAsString(truncate = false)
    val expected =
      """+-----+----------+----------+------+
        ||rowId|reportDate|legs      |errCol|
        |+-----+----------+----------+------+
        ||1    |2018-08-10|[[[1000]]]|[]    |
        ||2    |2018-08-10|[[[2000]]]|[]    |
        ||3    |2018-08-10|[[[]]]    |[]    |
        ||4    |2018-08-10|null      |[]    |
        |+-----+----------+----------+------+
        |
        |""".stripMargin.replace("\r\n", "\n")

    assert(actual == expected)
  }
}

0 comments on commit 603f3ac

Please sign in to comment.