@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.datasources.csv

import java.math.BigDecimal
-import java.sql.{Date, Timestamp}
import java.text.NumberFormat
import java.util.Locale

@@ -27,7 +26,9 @@ import scala.util.Try

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String

private[csv] object CSVInferSchema {

@@ -116,7 +117,7 @@ private[csv] object CSVInferSchema {
}

def tryParseTimestamp(field: String): DataType = {
-if ((allCatch opt Timestamp.valueOf(field)).isDefined) {
+if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
TimestampType
} else {
tryParseBoolean(field)
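Side note on the idiom above: `allCatch opt expr` (from `scala.util.control.Exception`) evaluates the expression and yields `Some(result)`, or `None` if it throws, which is what lets each `tryParse*` method probe one type and fall back to the next. A minimal, illustrative sketch of that probe-and-fall-back pattern; the object and method names here are made up for the example, not taken from the patch:

```scala
import scala.util.control.Exception.allCatch

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{DataType, StringType, TimestampType}

// Illustrative only: probe a field with a parser and fall back to a wider type.
object InferProbeSketch {
  def inferTimestampOrString(field: String): DataType = {
    // allCatch opt returns Some(parsed) on success, None if parsing throws.
    if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) TimestampType
    else StringType
  }

  def main(args: Array[String]): Unit = {
    println(inferTimestampOrString("2015-01-01 00:00:00")) // TimestampType
    println(inferTimestampOrString("not a timestamp"))     // StringType
  }
}
```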
@@ -191,12 +192,18 @@ private[csv] object CSVTypeCast {
case _: DoubleType => Try(datum.toDouble)
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
case _: BooleanType => datum.toBoolean
-case _: DecimalType => new BigDecimal(datum.replaceAll(",", ""))
+case dt: DecimalType =>
+val value = new BigDecimal(datum.replaceAll(",", ""))
+Decimal(value, dt.precision, dt.scale)
// TODO(hossein): would be good to support other common timestamp formats
Review comment (Contributor), on the TODO above: remove this todo?
-case _: TimestampType => Timestamp.valueOf(datum)
+case _: TimestampType =>
+// This one will lose the microseconds part.
+// See https://issues.apache.org/jira/browse/SPARK-10681.
+DateTimeUtils.stringToTime(datum).getTime * 1000L
// TODO(hossein): would be good to support other common date formats
-case _: DateType => Date.valueOf(datum)
-case _: StringType => datum
+case _: DateType =>
+DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
+case _: StringType => UTF8String.fromString(datum)
case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}")
}
}
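To make the microsecond caveat above concrete: `DateTimeUtils.stringToTime` returns a `java.util.Date`, whose `getTime` only carries millisecond resolution, so the microseconds derived from it always end in three zeros, unlike the old `java.sql.Timestamp` path. A small standalone sketch (not part of the patch; it assumes the `DateTimeUtils` API as used in this change):

```scala
import java.sql.Timestamp

import org.apache.spark.sql.catalyst.util.DateTimeUtils

// Illustrative comparison of the two timestamp conversion paths.
object TimestampPrecisionSketch {
  def main(args: Array[String]): Unit = {
    val field = "2015-01-01 00:00:00.123456"

    // Old path: java.sql.Timestamp keeps sub-millisecond digits (as nanos).
    println(Timestamp.valueOf(field).getNanos) // 123456000

    // New path: java.util.Date.getTime is millisecond-resolution, so the
    // derived microsecond value drops the trailing digits (SPARK-10681).
    val micros = DateTimeUtils.stringToTime(field).getTime * 1000L
    println(micros % 1000L) // 0
  }
}
```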
@@ -29,6 +29,7 @@ import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

@@ -54,7 +55,7 @@ object CSVRelation extends Logging {
requiredColumns: Array[String],
inputs: Seq[FileStatus],
sqlContext: SQLContext,
-params: CSVOptions): RDD[Row] = {
+params: CSVOptions): RDD[InternalRow] = {

val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
@@ -71,8 +72,8 @@
}.foreach {
case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index
}
-val rowArray = new Array[Any](safeRequiredIndices.length)
val requiredSize = requiredFields.length
+val row = new GenericMutableRow(requiredSize)
tokenizedRDD.flatMap { tokens =>
if (params.dropMalformed && schemaFields.length != tokens.length) {
logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
@@ -94,14 +95,20 @@ object CSVRelation extends Logging {
while (subIndex < safeRequiredIndices.length) {
index = safeRequiredIndices(subIndex)
val field = schemaFields(index)
-rowArray(subIndex) = CSVTypeCast.castTo(
+// The cast still needs to be attempted even when the value is not stored in
+// the row: in `DROPMALFORMED` mode, whether this row is malformed is decided
+// only after trying the cast.
+val value = CSVTypeCast.castTo(
indexSafeTokens(index),
field.dataType,
field.nullable,
params.nullValue)
+if (subIndex < requiredSize) {
+row(subIndex) = value
+}
subIndex = subIndex + 1
}
-Some(Row.fromSeq(rowArray.take(requiredSize)))
+Some(row)
} catch {
case NonFatal(e) if params.dropMalformed =>
logWarning("Parse exception. " +
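The hunk above also switches from building a fresh `Array[Any]` plus `Row.fromSeq` per line to a single `GenericMutableRow` that is refilled for every record. A minimal sketch of that reuse pattern, with a simplified all-string schema standing in for the real CSV casting logic (illustrative names, not the patch itself):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.unsafe.types.UTF8String

// Illustrative only: one mutable row per task, refilled for each tokenized line.
object RowReuseSketch {
  def toRows(lines: Iterator[Array[String]], requiredSize: Int): Iterator[InternalRow] = {
    val row = new GenericMutableRow(requiredSize)
    lines.map { tokens =>
      var i = 0
      while (i < requiredSize) {
        // The real code writes the result of CSVTypeCast.castTo here; this
        // sketch just stores each token as a UTF8String.
        row(i) = UTF8String.fromString(tokens(i))
        i += 1
      }
      // The same object is returned every time, so downstream consumers that
      // buffer rows must copy them first.
      row
    }
  }
}
```

The design point is to avoid allocating one array and one `Row` per input line in the hot parsing loop.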
@@ -28,7 +28,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.execution.datasources.CompressionCodecs
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
@@ -113,13 +113,14 @@ class DefaultSource extends FileFormat with DataSourceRegister {
val pathsString = csvFiles.map(_.getPath.toUri.toString)
val header = dataSchema.fields.map(_.name)
val tokenizedRdd = tokenRdd(sqlContext, csvOptions, header, pathsString)
-val external = CSVRelation.parseCsv(
+val rows = CSVRelation.parseCsv(
tokenizedRdd, dataSchema, requiredColumns, csvFiles, sqlContext, csvOptions)

-// TODO: Generate InternalRow in parseCsv
-val outputSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get))
-val encoder = RowEncoder(outputSchema)
-external.map(encoder.toRow)
+val requiredDataSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get))
+rows.mapPartitions { iterator =>
+val unsafeProjection = UnsafeProjection.create(requiredDataSchema)
+iterator.map(unsafeProjection)
+}
}


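For reference, a self-contained sketch of the projection step introduced above: an `UnsafeProjection` is created once per partition and converts each incoming `InternalRow` into an `UnsafeRow` matching the required schema. The schema and values below are invented for the example, not taken from the patch:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Illustrative only: project a generic InternalRow into the UnsafeRow format.
object UnsafeProjectionSketch {
  def main(args: Array[String]): Unit = {
    val requiredDataSchema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)))

    // Created once and reused for every row of a partition.
    val unsafeProjection = UnsafeProjection.create(requiredDataSchema)

    val row = InternalRow(UTF8String.fromString("alice"), 29)
    val unsafeRow = unsafeProjection(row)
    println(unsafeRow.getUTF8String(0)) // alice
    println(unsafeRow.getInt(1))        // 29
  }
}
```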
@@ -18,11 +18,12 @@
package org.apache.spark.sql.execution.datasources.csv

import java.math.BigDecimal
-import java.sql.{Date, Timestamp}
import java.util.Locale

import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String

class CSVTypeCastSuite extends SparkFunSuite {

@@ -32,7 +33,9 @@ class CSVTypeCastSuite extends SparkFunSuite {
val decimalType = new DecimalType()

stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
-assert(CSVTypeCast.castTo(strVal, decimalType) === new BigDecimal(decimalVal.toString))
+val decimalValue = new BigDecimal(decimalVal.toString)
+assert(CSVTypeCast.castTo(strVal, decimalType) ===
+Decimal(decimalValue, decimalType.precision, decimalType.scale))
}
}

@@ -65,8 +68,8 @@ }
}

test("String type should always return the same as the input") {
assert(CSVTypeCast.castTo("", StringType, nullable = true) == "")
assert(CSVTypeCast.castTo("", StringType, nullable = false) == "")
assert(CSVTypeCast.castTo("", StringType, nullable = true) == UTF8String.fromString(""))
assert(CSVTypeCast.castTo("", StringType, nullable = false) == UTF8String.fromString(""))
}

test("Throws exception for empty string with non null type") {
@@ -85,8 +88,10 @@ assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0)
assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0)
assert(CSVTypeCast.castTo("true", BooleanType) == true)
val timestamp = "2015-01-01 00:00:00"
-assert(CSVTypeCast.castTo(timestamp, TimestampType) == Timestamp.valueOf(timestamp))
-assert(CSVTypeCast.castTo("2015-01-01", DateType) == Date.valueOf("2015-01-01"))
+assert(CSVTypeCast.castTo(timestamp, TimestampType) ==
+DateTimeUtils.stringToTime(timestamp).getTime * 1000L)
+assert(CSVTypeCast.castTo("2015-01-01", DateType) ==
+DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime))
}

test("Float and Double Types are cast correctly with Locale") {