Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ import org.apache.spark.unsafe.types.UTF8String
object RowEncoder {
def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
// We use an If expression to wrap extractorsFor result of StructType
val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue
val inputObject = BoundReference(0, ObjectType(cls), nullable = false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reason of this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input object should never be null, we also use this assumption in ExpressionEncoder

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually ExpressionEncoder allows null input now. But I agree that for RowEncoder this assumption is reasonable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably add a comment here. It's not super intuitive.

val serializer = serializerFor(inputObject, schema)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is also because we don't allow null input Rows now, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, there is no if anymore for the top row object.

val deserializer = deserializerFor(schema)
new ExpressionEncoder[Row](
schema,
Expand Down Expand Up @@ -130,21 +129,28 @@ object RowEncoder {

case StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
val method = if (f.dataType.isInstanceOf[StructType]) {
"getStruct"
val fieldValue = serializerFor(
GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)),
f.dataType
)
if (f.nullable) {
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
fieldValue
)
} else {
"get"
fieldValue
}
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
serializerFor(
Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
f.dataType))
}
If(IsNull(inputObject),
Literal.create(null, inputType),
CreateStruct(convertedFields))

if (inputObject.nullable) {
If(IsNull(inputObject),
Literal.create(null, inputType),
CreateStruct(convertedFields))
} else {
CreateStruct(convertedFields)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,3 +688,45 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
ev.copy(code = code, isNull = "false", value = childGen.value)
}
}

/**
* Returns the value of field at index `index` from the external row `child`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to add a more intuitive description:

This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s.

It took me some time to realize this...

* This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s.
*
* Note that the input row and the field we try to get are both guaranteed to be not null, if they
* are null, a runtime exception will be thrown.
*/
case class GetExternalRowField(
child: Expression,
index: Int,
dataType: DataType) extends UnaryExpression with NonSQLExpression {

override def nullable: Boolean = false

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR. We might want to add a CodegenOnly trait extending from Expression to avoid duplicating this one.


override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val row = child.genCode(ctx)

val getField = dataType match {
case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)"""
case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)"""
Copy link
Contributor

@liancheng liancheng May 5, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using CodegenContext.getValue() here to generate specialized code to avoid boxing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't. For object type, CodegenContext.getValue() will generate row.get(ordinal, null), which should be row.get(ordinal) for external row.

}

val code = s"""
${row.code}

if (${row.isNull}) {
throw new RuntimeException("The input external row cannot be null.");
}

if (${row.value}.isNullAt($index)) {
throw new RuntimeException("The ${index}th field of input row cannot be null.");
}

final ${ctx.javaType(dataType)} ${ev.value} = $getField;
"""
ev.copy(code = code, isNull = "false")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ class RowEncoderSuite extends SparkFunSuite {
.compareTo(convertedBack.getDecimal(3)) == 0)
}

test("RowEncoder should preserve schema nullability") {
val schema = new StructType().add("int", IntegerType, nullable = false)
val encoder = RowEncoder(schema)
assert(encoder.serializer.length == 1)
assert(encoder.serializer.head.dataType == IntegerType)
assert(encoder.serializer.head.nullable == false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will we throw an exception if there is a null in the data?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point, actually we will, we should add the runtime null check like we did for product encoder

}

private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {
val encoder = RowEncoder(schema)
Expand Down
18 changes: 17 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}

import scala.language.postfixOps

import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
Expand Down Expand Up @@ -658,6 +658,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val dataset = Seq(1, 2, 3).toDS()
checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
}

test("runtime null check for RowEncoder") {
val schema = new StructType().add("i", IntegerType, nullable = false)
val df = sqlContext.range(10).map(l => {
if (l % 5 == 0) {
Row(null)
} else {
Row(l)
}
})(RowEncoder(schema))

val message = intercept[Exception] {
df.collect()
}.getMessage
assert(message.contains("The 0th field of input row cannot be null"))
}
}

case class OtherTuple(_1: String, _2: Int)
Expand Down