10 changes: 1 addition & 9 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -304,15 +304,7 @@ trait Row extends Serializable {
   *
   * @throws ClassCastException when data type does not match.
   */
-  def getStruct(i: Int): Row = {
-    // Product and Row both are recognized as StructType in a Row
-    val t = get(i)
-    if (t.isInstanceOf[Product]) {
-      Row.fromTuple(t.asInstanceOf[Product])
-    } else {
-      t.asInstanceOf[Row]
-    }
-  }
+  def getStruct(i: Int): Row = getAs[Row](i)

  /**
   * Returns the value at position i.
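Note the behavior change hidden in this simplification: getStruct no longer converts a Product (for example, a tuple) stored at position i into a Row via Row.fromTuple; getAs[Row] simply casts, so a non-Row struct value now fails with the documented ClassCastException. A minimal illustrative sketch (the rows here are hypothetical, not from the patch):

    import org.apache.spark.sql.Row

    val ok = Row(Row(1, "a"))   // nested struct stored as a Row
    ok.getStruct(0)             // returns Row(1, "a") via getAs[Row]

    val bad = Row((1, "a"))     // nested struct stored as a tuple (Product)
    // bad.getStruct(0)         // before: converted by Row.fromTuple;
                                // after this patch: ClassCastException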
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -51,7 +51,7 @@ import org.apache.spark.unsafe.types.UTF8String
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
  *   MapType -> scala.collection.Map
- *   StructType -> org.apache.spark.sql.Row or Product
+ *   StructType -> org.apache.spark.sql.Row
  * }}}
  */
 object RowEncoder {
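For illustration, an external Row conforming to the mapping above could look like the following hedged sketch (the schema is made up for this example; note that StructType now maps to Row only, no longer to Product):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    val schema = new StructType()
      .add("tags", ArrayType(StringType))                     // ArrayType -> Seq or Array
      .add("counts", MapType(StringType, IntegerType))        // MapType -> scala.collection.Map
      .add("nested", new StructType().add("i", IntegerType))  // StructType -> Row

    // A conforming external row: Seq for the array, Map for the map, Row for the struct.
    val external = Row(Seq("a", "b"), Map("a" -> 1), Row(42))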
@@ -121,11 +121,15 @@

       case t @ ArrayType(et, _) => et match {
         case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+          // TODO: validate input type for primitive array.
           NewInstance(
             classOf[GenericArrayData],
             inputObject :: Nil,
             dataType = t)
-        case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et))
+        case _ => MapObjects(
+          element => serializerFor(ValidateExternalType(element, et), et),
+          inputObject,
+          ObjectType(classOf[Object]))
       }

       case t @ MapType(kt, vt, valueNullable) =>
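With this change, each non-primitive array element passes through ValidateExternalType before being serialized, so a mismatched element type now fails with a descriptive message rather than a cast error deep inside the converter. Roughly, mirroring the e4 test added to RowEncoderSuite below:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = new StructType().add("a", ArrayType(TimestampType))
    val encoder = RowEncoder(schema)
    // Strings are not a valid external type for timestamp elements:
    encoder.toRow(Row(Array("a")))
    // => RuntimeException: java.lang.String is not a valid external type ...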
@@ -151,8 +155,9 @@
       case StructType(fields) =>
         val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
           val fieldValue = serializerFor(
-            GetExternalRowField(
-              inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
+            ValidateExternalType(
+              GetExternalRowField(inputObject, index, field.name),
+              field.dataType),
             field.dataType)
           val convertedField = if (field.nullable) {
             If(
@@ -183,7 +188,7 @@
   * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
   * `org.apache.spark.sql.types.Decimal`.
   */
-  private def externalDataTypeForInput(dt: DataType): DataType = dt match {
+  def externalDataTypeForInput(dt: DataType): DataType = dt match {
     // In order to support both Decimal and java/scala BigDecimal in external row, we make this
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
@@ -192,7 +197,7 @@
     case _ => externalDataTypeFor(dt)
   }

-  private def externalDataTypeFor(dt: DataType): DataType = dt match {
+  def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
     case DateType => ObjectType(classOf[java.sql.Date])
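Since externalDataTypeForInput widens DecimalType to java.lang.Object, an input row may carry any of the three decimal representations, and ValidateExternalType accepts all of them at runtime. A hedged sketch:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = new StructType().add("d", DecimalType(10, 2))
    val encoder = RowEncoder(schema)

    // All three external decimal representations should be accepted:
    encoder.toRow(Row(BigDecimal("1.23")))               // scala.math.BigDecimal
    encoder.toRow(Row(new java.math.BigDecimal("1.23"))) // java.math.BigDecimal
    encoder.toRow(Row(Decimal("1.23")))                  // org.apache.spark.sql.types.Decimal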
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -26,6 +26,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
@@ -692,22 +693,17 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
 case class GetExternalRowField(
     child: Expression,
     index: Int,
-    fieldName: String,
-    dataType: DataType) extends UnaryExpression with NonSQLExpression {
+    fieldName: String) extends UnaryExpression with NonSQLExpression {

   override def nullable: Boolean = false

+  override def dataType: DataType = ObjectType(classOf[Object])
+
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val row = child.genCode(ctx)

-    val getField = dataType match {
-      case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)"""
-      case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)"""
-    }
-
     val code = s"""
       ${row.code}
@@ -720,8 +716,55 @@ case class GetExternalRowField(
           "cannot be null.");
       }

-      final ${ctx.javaType(dataType)} ${ev.value} = $getField;
+      final Object ${ev.value} = ${row.value}.get($index);
     """
     ev.copy(code = code, isNull = "false")
   }
 }

+/**
+ * Validates the actual data type of the input expression at runtime. If it doesn't match the
+ * expectation, throws an exception.
+ */
+case class ValidateExternalType(child: Expression, expected: DataType)
[Review comment from a Member]
Previously we relied entirely on the correctness of the provided schema: if the data didn't match it, we would see a value-conversion error. This ValidateExternalType should improve the error-handling experience, but it looks like a performance regression, since we do extra checking here. If we can trust the provided schema, should we do this checking at all?

[Reply from cloud-fan (author)]
The problem is that we can't trust it. When users call createDataFrame(rows, schema), we should definitely validate the passed-in rows. I think performance doesn't matter too much here, as this only happens at the beginning of the data flow. One potential issue is that Dataset.map could return rows along with a user-provided schema that we should trust; however, I don't think we should expose RowEncoder to users, and Dataset.map should never return a Row.

[Reply from the Member]
Hmm, as for trust, it would be fairer to say that we leave the responsibility for data correctness to users. When the data has a wrong type we don't actually get a wrong result; an exception about data conversion is thrown instead. Still, better error handling is always good if, as you said, performance is not a big issue here.
+    extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object]))
+
+  override def nullable: Boolean = child.nullable
+
+  override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val input = child.genCode(ctx)
+    val obj = input.value
+
+    val typeCheck = expected match {
+      case _: DecimalType =>
+        Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
+          .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
+      case _: ArrayType =>
+        s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
+      case _ =>
+        s"$obj instanceof ${ctx.boxedType(dataType)}"
+    }
+
+    val code = s"""
+      ${input.code}
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
[Review comment from a Member]
ctx.boxedType here too?

[Reply from cloud-fan (author), May 31, 2016]
This is intentional: we can't cast an object to int directly; we have to cast it to the boxed int first.
+      if (!${input.isNull}) {
+        if ($typeCheck) {
+          ${ev.value} = (${ctx.boxedType(dataType)}) $obj;
+        } else {
+          throw new RuntimeException($obj.getClass().getName() + " is not a valid " +
+            "external type for schema of ${expected.simpleString}");
+        }
+      }
+    """
+    ev.copy(code = code, isNull = input.isNull)
+  }
+}
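Taken together, the Java emitted by doGenCode is morally equivalent to the following plain-Scala check. This is a hedged sketch only: validate and expectedClass are hypothetical stand-ins, and the real check is generated and inlined per field. expectedClass plays the role of ctx.boxedType(dataType); per the review exchange above, an Object must be cast to the boxed type such as java.lang.Integer, never straight to a primitive int.

    import org.apache.spark.sql.types._

    // Hypothetical helper mirroring the generated type check.
    def validate(obj: Any, expected: DataType, expectedClass: Class[_]): Any = {
      if (obj == null) return null  // a null input short-circuits the check
      val ok = expected match {
        case _: DecimalType =>
          obj.isInstanceOf[java.math.BigDecimal] ||
            obj.isInstanceOf[scala.math.BigDecimal] ||
            obj.isInstanceOf[Decimal]
        case _: ArrayType =>
          obj.isInstanceOf[Seq[_]] || obj.getClass.isArray
        case _ =>
          expectedClass.isInstance(obj)  // e.g. classOf[java.lang.Integer] for IntegerType
      }
      if (ok) obj
      else throw new RuntimeException(obj.getClass.getName +
        s" is not a valid external type for schema of ${expected.simpleString}")
    }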
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -127,22 +127,6 @@ class RowEncoderSuite extends SparkFunSuite {
       new StructType().add("array", arrayOfString).add("map", mapOfString))
     .add("structOfUDT", structOfUDT))

-  test(s"encode/decode: Product") {
-    val schema = new StructType()
-      .add("structAsProduct",
-        new StructType()
-          .add("int", IntegerType)
-          .add("string", StringType)
-          .add("double", DoubleType))
-
-    val encoder = RowEncoder(schema).resolveAndBind()
-
-    val input: Row = Row((100, "test", 0.123))
-    val row = encoder.toRow(input)
-    val convertedBack = encoder.fromRow(row)
-    assert(input.getStruct(0) == convertedBack.getStruct(0))
-  }

test("encode/decode decimal type") {
val schema = new StructType()
.add("int", IntegerType)
@@ -232,6 +216,37 @@
     assert(e.getMessage.contains("top level row object"))
   }

+  test("RowEncoder should validate external type") {
+    val e1 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", IntegerType)
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1.toShort))
+    }
+    assert(e1.getMessage.contains("java.lang.Short is not a valid external type"))
+
+    val e2 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", StringType)
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1))
+    }
+    assert(e2.getMessage.contains("java.lang.Integer is not a valid external type"))
+
+    val e3 = intercept[RuntimeException] {
+      val schema = new StructType().add("a",
+        new StructType().add("b", IntegerType).add("c", StringType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1 -> "a"))
+    }
+    assert(e3.getMessage.contains("scala.Tuple2 is not a valid external type"))
+
+    val e4 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", ArrayType(TimestampType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(Array("a")))
+    }
+    assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
+  }
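As the author argues in the review thread above, the user-facing win is createDataFrame(rows, schema), where the passed-in rows cannot be trusted to match the schema. A hedged end-to-end sketch: spark is an assumed SparkSession, the sketch assumes the rows are funneled through RowEncoder as in the discussion, and exactly where the error surfaces depends on the code path (typically at action time).

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    val schema = new StructType().add("a", IntegerType)
    val rdd = spark.sparkContext.parallelize(Seq(Row(1.toShort)))  // Short, not Integer

    spark.createDataFrame(rdd, schema).collect()
    // Before this patch: an opaque value-conversion / cast error.
    // After: RuntimeException: java.lang.Short is not a valid external type
    //        for schema of int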

private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {
val encoder = RowEncoder(schema).resolveAndBind()