diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
index ad04832..7a00646 100644
--- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -42,6 +42,7 @@ class CsvParser extends Serializable {
   private var codec: String = null
   private var nullValue: String = ""
   private var dateFormat: String = null
+  private var treatParseExceptionAsNull: Boolean = false

   def withUseHeader(flag: Boolean): CsvParser = {
     this.useHeader = flag
@@ -123,6 +124,16 @@ class CsvParser extends Serializable {
     this
   }

+  /**
+   * If set to true, then dirty data (for example, a string in a numeric
+   * column, or a malformed date) will not cause a failure. Instead, that
+   * value will be null in the resulting DataFrame.
+   */
+  def withTreatParseExceptionAsNull(flag: Boolean): CsvParser = {
+    this.treatParseExceptionAsNull = flag
+    this
+  }
+
   /** Returns a Schema RDD for the given CSV path. */
   @throws[RuntimeException]
   def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
@@ -143,7 +154,8 @@ class CsvParser extends Serializable {
       inferSchema,
       codec,
       nullValue,
-      dateFormat)(sqlContext)
+      dateFormat,
+      treatParseExceptionAsNull)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }

@@ -165,7 +177,8 @@ class CsvParser extends Serializable {
       inferSchema,
       codec,
       nullValue,
-      dateFormat)(sqlContext)
+      dateFormat,
+      treatParseExceptionAsNull)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }
 }
diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index 3e36931..dce8169 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -49,7 +49,9 @@ case class CsvRelation protected[spark] (
     inferCsvSchema: Boolean,
     codec: String = null,
     nullValue: String = "",
-    dateFormat: String = null)(@transient val sqlContext: SQLContext)
+    dateFormat: String = null,
+    treatParseExceptionAsNull: Boolean = false)
+    (@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with PrunedScan with InsertableRelation {

   // Share date format object as it is expensive to parse date pattern.
@@ -118,7 +120,8 @@ case class CsvRelation protected[spark] (
         while (index < schemaFields.length) {
           val field = schemaFields(index)
           rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
-            treatEmptyValuesAsNulls, nullValue, simpleDateFormatter)
+            treatEmptyValuesAsNulls, nullValue, simpleDateFormatter,
+            treatParseExceptionAsNull)
           index = index + 1
         }
         Some(Row.fromSeq(rowArray))
@@ -197,7 +200,8 @@ case class CsvRelation protected[spark] (
             field.nullable,
             treatEmptyValuesAsNulls,
             nullValue,
-            simpleDateFormatter
+            simpleDateFormatter,
+            treatParseExceptionAsNull
           )
           subIndex = subIndex + 1
         }
diff --git a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
index 220c380..538f114 100755
--- a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -125,6 +125,16 @@ class DefaultSource
       throw new Exception("Treat empty values as null flag can be true or false")
     }

+    val treatParseExceptionAsNull = parameters.getOrElse(
+      "treatParseExceptionAsNull", "false")
+    val treatParseExceptionAsNullFlag = if (treatParseExceptionAsNull == "false") {
+      false
+    } else if (treatParseExceptionAsNull == "true") {
+      true
+    } else {
+      throw new Exception("Treat parse exception as null flag can be true or false")
+    }
+
     val charset = parameters.getOrElse("charset", TextFile.DEFAULT_CHARSET.name())
     // TODO validate charset?

@@ -159,7 +169,8 @@ class DefaultSource
       inferSchemaFlag,
       codec,
       nullValue,
-      dateFormat)(sqlContext)
+      dateFormat,
+      treatParseExceptionAsNullFlag)(sqlContext)
   }

   override def createRelation(
diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala
index 0195fb0..2d9c4cd 100755
--- a/src/main/scala/com/databricks/spark/csv/package.scala
+++ b/src/main/scala/com/databricks/spark/csv/package.scala
@@ -56,7 +56,8 @@ package object csv {
       ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
       treatEmptyValuesAsNulls = false,
-      inferCsvSchema = inferSchema)(sqlContext)
+      inferCsvSchema = inferSchema,
+      treatParseExceptionAsNull = false)(sqlContext)
     sqlContext.baseRelationToDataFrame(csvRelation)
   }

@@ -81,7 +82,8 @@ package object csv {
       ignoreLeadingWhiteSpace = ignoreLeadingWhiteSpace,
       ignoreTrailingWhiteSpace = ignoreTrailingWhiteSpace,
       treatEmptyValuesAsNulls = false,
-      inferCsvSchema = inferSchema)(sqlContext)
+      inferCsvSchema = inferSchema,
+      treatParseExceptionAsNull = false)(sqlContext)
     sqlContext.baseRelationToDataFrame(csvRelation)
   }
 }
diff --git a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
index 11d5a07..f1f78b1 100644
--- a/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
+++ b/src/main/scala/com/databricks/spark/csv/util/TypeCast.scala
@@ -21,6 +21,7 @@ import java.text.{SimpleDateFormat, NumberFormat}
 import java.util.Locale

 import org.apache.spark.sql.types._
+import scala.util.control.NonFatal

 import scala.util.Try

@@ -45,7 +46,8 @@ object TypeCast {
       nullable: Boolean = true,
       treatEmptyValuesAsNulls: Boolean = false,
       nullValue: String = "",
-      dateFormatter: SimpleDateFormat = null): Any = {
+      dateFormatter: SimpleDateFormat = null,
+      parseExceptionAsNull: Boolean = false): Any = {
     // if nullValue is not an empty string, don't require treatEmptyValuesAsNulls
     // to be set to true
     val nullValueIsNotEmpty = nullValue != ""
!= "" @@ -55,25 +57,34 @@ object TypeCast { ){ null } else { - castType match { - case _: ByteType => datum.toByte - case _: ShortType => datum.toShort - case _: IntegerType => datum.toInt - case _: LongType => datum.toLong - case _: FloatType => Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) - case _: DoubleType => Try(datum.toDouble) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) - case _: BooleanType => datum.toBoolean - case _: DecimalType => new BigDecimal(datum.replaceAll(",", "")) - case _: TimestampType if dateFormatter != null => - new Timestamp(dateFormatter.parse(datum).getTime) - case _: TimestampType => Timestamp.valueOf(datum) - case _: DateType if dateFormatter != null => - new Date(dateFormatter.parse(datum).getTime) - case _: DateType => Date.valueOf(datum) - case _: StringType => datum - case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + try { + castType match { + case _: ByteType => datum.toByte + case _: ShortType => datum.toShort + case _: IntegerType => datum.toInt + case _: LongType => datum.toLong + case _: FloatType => Try(datum.toFloat) + .getOrElse(NumberFormat + .getInstance(Locale.getDefault).parse(datum).floatValue()) + case _: DoubleType => Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.getDefault) + .parse(datum).doubleValue()) + case _: BooleanType => datum.toBoolean + case _: DecimalType => new BigDecimal(datum.replaceAll(",", "")) + case _: TimestampType if dateFormatter != null => + new Timestamp(dateFormatter.parse(datum).getTime) + case _: TimestampType => Timestamp.valueOf(datum) + case _: DateType if dateFormatter != null => + new Date(dateFormatter.parse(datum).getTime) + case _: DateType => Date.valueOf(datum) + case _: StringType => datum + case _ => throw new UnsupportedTypeException(s"Unsupported type: ${castType.typeName}") + } + } + catch { + case e : UnsupportedTypeException => + throw e + case e => if (parseExceptionAsNull && nullable) null else throw e } } } @@ -106,3 +117,6 @@ object TypeCast { } } } + +class UnsupportedTypeException(message: String = null, cause: Throwable = null) + extends RuntimeException(message, cause) diff --git a/src/test/resources/cars_dirty.csv b/src/test/resources/cars_dirty.csv new file mode 100644 index 0000000..d9ce62a --- /dev/null +++ b/src/test/resources/cars_dirty.csv @@ -0,0 +1,5 @@ +year,make,model,price,comment,blank +2012,Tesla,S"80,000.65" +2013.5,Ford,E350,35,000,"Go get one now they are going fast" +2015,,Volt,5,000 +new,"",Volt,5000.00 diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala index b3cd644..f0dd092 100755 --- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala @@ -32,6 +32,7 @@ import org.scalatest.Matchers._ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll { val carsFile = "src/test/resources/cars.csv" val carsMalformedFile = "src/test/resources/cars-malformed.csv" + val carsDirtyTsvFile = "src/test/resources/cars_dirty.csv" val carsFile8859 = "src/test/resources/cars_iso-8859-1.csv" val carsTsvFile = "src/test/resources/cars.tsv" val carsAltFile = "src/test/resources/cars-alternative.csv" @@ -67,6 +68,12 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll { super.afterAll() } } + test("Dirty Data CSV"){ + val results = sqlContext.csvFile( + carsDirtyTsvFile, 
+      parserLib = parserLib
+    ).collect()
+    assert(results.length == 4)
+  }
+
   test("DSL test") {
     val results = sqlContext
@@ -197,9 +204,47 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
       .select("Name")
       .collect().size

+    val resultsWithSchema = new CsvParser()
+      .withSchema(strictSchema)
+      .withUseHeader(true)
+      .withParserLib(parserLib)
+      .withParseMode(ParseModes.DROP_MALFORMED_MODE)
+      .csvFile(sqlContext, ageFile)
+      .select("Name")
+      .collect()
+
     assert(results === 1)
+    assert(resultsWithSchema.length === 1)
   }

+  test("Parse Exception with Schema") {
+    val carsSchema = new StructType(
+      Array(
+        StructField("year", IntegerType, nullable = true),
+        StructField("make", StringType, nullable = true),
+        StructField("model", StringType, nullable = true),
+        StructField("price", DoubleType, nullable = true),
+        StructField("comment", StringType, nullable = true),
+        StructField("blank", IntegerType, nullable = true)
+      )
+    )
+
+    val results = new CsvParser()
+      .withSchema(carsSchema)
+      .withUseHeader(true)
+      .withDelimiter(',')
+      .withQuoteChar('\"')
+      .withTreatParseExceptionAsNull(true)
+      .csvFile(sqlContext, carsDirtyFile)
+      .select("year", "make")
+      .collect()
+
+    assert(results(0).toSeq == Seq(2012, "Tesla"))
+    assert(results(1).toSeq == Seq(null, "Ford"))
+    assert(results(2).toSeq == Seq(2015, ""))
+    assert(results(3).toSeq == Seq(null, ""))
+  }
+
   test("DSL test for FAILFAST parsing mode") {
     val parser = new CsvParser()
       .withParseMode(ParseModes.FAIL_FAST_MODE)
@@ -267,6 +309,7 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
       .withDelimiter(',')
       .withQuoteChar(null)
       .withUseHeader(true)
+      .withNullValue("")
       .withParserLib(parserLib)
       .csvFile(sqlContext, carsUnbalancedQuotesFile)
       .select("year")
@@ -677,7 +720,7 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(results.schema == StructType(List(
       StructField("year", IntegerType, nullable = true),
       StructField("make", StringType, nullable = true),
-      StructField("model", StringType ,nullable = true),
+      StructField("model", StringType, nullable = true),
       StructField("comment", StringType, nullable = true),
       StructField("blank", StringType, nullable = true))
     ))
diff --git a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
index 448debf..52c2a06 100644
--- a/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/util/TypeCastSuite.scala
@@ -24,6 +24,8 @@ import org.scalatest.FunSuite

 import org.apache.spark.sql.types._

+import scala.util.Try
+
 class TypeCastSuite extends FunSuite {

   test("Can parse decimal type values") {
@@ -115,4 +117,45 @@ class TypeCastSuite extends FunSuite {
     assert(TypeCast.castTo("", StringType, true, false, "") == "")
     assert(TypeCast.castTo("", StringType, true, true, "") == null)
   }
+
+  test("Parse exception is caught correctly") {
+    def testParseException(castType: DataType, badValues: Seq[String]): Unit = {
+      badValues.foreach { testValue =>
+        // With parseExceptionAsNull set and a nullable field, the bad value becomes null.
+        assert(TypeCast.castTo(testValue, castType, true, false, "", null, true) == null)
+        // If the field is not nullable, the parse exception still propagates.
+        assert(Try(TypeCast.castTo(testValue, castType, false, false, "", null, true)).isFailure)
+      }
+    }
+
+    assert(TypeCast.castTo("10", ByteType, true, false, "", null, true) == 10)
+    testParseException(ByteType, Seq("10.5", "s", "true"))
+
+    assert(TypeCast.castTo("10", ShortType, true, false, "", null, true) == 10)
+    testParseException(ShortType, Seq("s", "true"))
+
+    assert(TypeCast.castTo("10", IntegerType, true, false, "", null, true) == 10)
+    testParseException(IntegerType, Seq("10.5", "s", "true"))
+
+    assert(TypeCast.castTo("10", LongType, true, false, "", null, true) == 10)
+    testParseException(LongType, Seq("10.5", "s", "true"))
+
+    assert(TypeCast.castTo("1.00", FloatType, true, false, "", null, true) == 1.0)
+    testParseException(FloatType, Seq("s", "true"))
+
+    assert(TypeCast.castTo("1.00", DoubleType, true, false, "", null, true) == 1.0)
+    testParseException(DoubleType, Seq("s", "true"))
+
+    assert(TypeCast.castTo("true", BooleanType, true, false, "", null, true) == true)
+    testParseException(BooleanType, Seq("s", "5"))
+
+    val timestamp = "2015-01-01 00:00:00"
+    assert(TypeCast.castTo(timestamp, TimestampType, true, false, "", null, true)
+      == Timestamp.valueOf(timestamp))
+    testParseException(TimestampType, Seq("5", "string"))
+
+    assert(TypeCast.castTo("2015-01-01", DateType, true, false, "", null, true)
+      == Date.valueOf("2015-01-01"))
+    testParseException(DateType, Seq("5", "string", timestamp))
+  }
 }
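
With this patch applied, the flag is exercised through the programmatic builder API. A minimal sketch, assuming an existing SQLContext named sqlContext and a caller-supplied StructType named mySchema (both hypothetical names; the file path reuses the test fixture added above):

    import com.databricks.spark.csv.CsvParser

    val df = new CsvParser()
      .withUseHeader(true)
      .withSchema(mySchema)  // typed columns are what make parse failures possible
      .withTreatParseExceptionAsNull(true)
      .csvFile(sqlContext, "src/test/resources/cars_dirty.csv")

    // A cell that cannot be cast to its column's type (e.g. "new" in an
    // IntegerType year column) comes back as null instead of failing the
    // scan, provided the column is nullable.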
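
The DefaultSource change exposes the same switch as a string option. A sketch, assuming Spark 1.4+'s DataFrameReader (older versions would go through sqlContext.load with an options map); note the option only has an effect when the declared or inferred schema contains non-string columns:

    val df = sqlContext.read
      .format("com.databricks.spark.csv")
      .option("header", "true")
      .option("treatParseExceptionAsNull", "true")
      .load("src/test/resources/cars_dirty.csv")

Any option value other than "true" or "false" is rejected with an exception, mirroring the existing treatEmptyValuesAsNulls validation.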
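
At the lowest level, the behavior is just the extra parameter on TypeCast.castTo, as the updated TypeCastSuite exercises. The positional arguments below follow the patched signature (datum, castType, nullable, treatEmptyValuesAsNulls, nullValue, dateFormatter, parseExceptionAsNull):

    import org.apache.spark.sql.types.IntegerType
    import com.databricks.spark.csv.util.TypeCast

    TypeCast.castTo("abc", IntegerType, true, false, "", null, true)   // null: flag on, nullable
    TypeCast.castTo("abc", IntegerType, false, false, "", null, true)  // throws: column not nullable
    TypeCast.castTo("abc", IntegerType, true, false, "", null, false)  // throws: flag off (default)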