-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-14139][SQL] RowEncoder should preserve schema nullability #12364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
847d7c7
b500c8b
3e32b6a
15adc9c
2600a2e
50c0005
05e5f19
4b140e5
e9a9a30
8231a15
2976d26
7a1877a
ad8c9ef
4ee75e9
ded8800
645f0a0
7c4c91c
8870650
98e1463
18aa126
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,9 +35,8 @@ import org.apache.spark.unsafe.types.UTF8String | |
| object RowEncoder { | ||
| def apply(schema: StructType): ExpressionEncoder[Row] = { | ||
| val cls = classOf[Row] | ||
| val inputObject = BoundReference(0, ObjectType(cls), nullable = true) | ||
| // We use an If expression to wrap extractorsFor result of StructType | ||
| val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue | ||
| val inputObject = BoundReference(0, ObjectType(cls), nullable = false) | ||
| val serializer = serializerFor(inputObject, schema) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This change is also because we don't allow null input
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, there is no `If` expression anymore for the top-level row object. |
||
| val deserializer = deserializerFor(schema) | ||
| new ExpressionEncoder[Row]( | ||
| schema, | ||
|
|
@@ -130,21 +129,28 @@ object RowEncoder { | |
|
|
||
| case StructType(fields) => | ||
| val convertedFields = fields.zipWithIndex.map { case (f, i) => | ||
| val method = if (f.dataType.isInstanceOf[StructType]) { | ||
| "getStruct" | ||
| val fieldValue = serializerFor( | ||
| GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)), | ||
| f.dataType | ||
| ) | ||
| if (f.nullable) { | ||
| If( | ||
| Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), | ||
| Literal.create(null, f.dataType), | ||
| fieldValue | ||
| ) | ||
| } else { | ||
| "get" | ||
| fieldValue | ||
| } | ||
| If( | ||
| Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), | ||
| Literal.create(null, f.dataType), | ||
| serializerFor( | ||
| Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil), | ||
| f.dataType)) | ||
| } | ||
| If(IsNull(inputObject), | ||
| Literal.create(null, inputType), | ||
| CreateStruct(convertedFields)) | ||
|
|
||
| if (inputObject.nullable) { | ||
| If(IsNull(inputObject), | ||
| Literal.create(null, inputType), | ||
| CreateStruct(convertedFields)) | ||
| } else { | ||
| CreateStruct(convertedFields) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -688,3 +688,45 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) | |
| ev.copy(code = code, isNull = "false", value = childGen.value) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Returns the value of field at index `index` from the external row `child`. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be nice to add a more intuitive description:
It took me some time to realize this... |
||
| * This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s. | ||
| * | ||
| * Note that the input row and the field we try to get are both expected to be non-null; if | ||
| * either of them is null, a runtime exception will be thrown. | ||
| */ | ||
| case class GetExternalRowField( | ||
| child: Expression, | ||
| index: Int, | ||
| dataType: DataType) extends UnaryExpression with NonSQLExpression { | ||
|
|
||
| override def nullable: Boolean = false | ||
|
|
||
| override def eval(input: InternalRow): Any = | ||
| throw new UnsupportedOperationException("Only code-generated evaluation is supported") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not related to this PR. We might want to add a |
||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| val row = child.genCode(ctx) | ||
|
|
||
| val getField = dataType match { | ||
| case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)""" | ||
| case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)""" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can't. For object type, |
||
| } | ||
|
|
||
| val code = s""" | ||
| ${row.code} | ||
|
|
||
| if (${row.isNull}) { | ||
| throw new RuntimeException("The input external row cannot be null."); | ||
| } | ||
|
|
||
| if (${row.value}.isNullAt($index)) { | ||
| throw new RuntimeException("The ${index}th field of input row cannot be null."); | ||
| } | ||
|
|
||
| final ${ctx.javaType(dataType)} ${ev.value} = $getField; | ||
| """ | ||
| ev.copy(code = code, isNull = "false") | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -160,6 +160,14 @@ class RowEncoderSuite extends SparkFunSuite { | |
| .compareTo(convertedBack.getDecimal(3)) == 0) | ||
| } | ||
|
|
||
| test("RowEncoder should preserve schema nullability") { | ||
| val schema = new StructType().add("int", IntegerType, nullable = false) | ||
| val encoder = RowEncoder(schema) | ||
| assert(encoder.serializer.length == 1) | ||
| assert(encoder.serializer.head.dataType == IntegerType) | ||
| assert(encoder.serializer.head.nullable == false) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Will we throw an exception if there is a null in the data?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a good point. Actually we will; we should add the runtime null check like we did for the product encoder. |
||
| } | ||
|
|
||
| private def encodeDecodeTest(schema: StructType): Unit = { | ||
| test(s"encode/decode: ${schema.simpleString}") { | ||
| val encoder = RowEncoder(schema) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reason of this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The input object should never be null; we also use this assumption in
`ExpressionEncoder`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, `ExpressionEncoder` allows null input now. But I agree that for `RowEncoder` this assumption is reasonable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably add a comment here. It's not super intuitive.