From 847d7c7cdfe6626ea1f73656f9eaf868d641ae1c Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Sat, 26 Mar 2016 16:14:18 -0400 Subject: [PATCH 01/11] change RowEncoder to respect nullable for struct fields when generating extractors --- .../sql/catalyst/encoders/RowEncoder.scala | 18 ++++++++++++------ .../sql/catalyst/expressions/objects.scala | 4 ++-- .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++ 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 902644e735ea..f8615deb5f66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -125,12 +125,18 @@ object RowEncoder { } else { "get" } - If( - Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), - Literal.create(null, f.dataType), - extractorsFor( - Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil), - f.dataType)) + val x = extractorsFor( + Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil, + f.nullable), + f.dataType) + if (f.nullable) { + If( + Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), + Literal.create(null, f.dataType), + x) + } else { + x + } } If(IsNull(inputObject), Literal.create(null, inputType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 7eba617fcde5..08008ab907dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -109,9 +109,9 @@ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, - arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { + arguments: Seq[Expression] = Nil, + val nullable: Boolean = true) extends Expression with NonSQLExpression { - override def nullable: Boolean = true override def children: Seq[Expression] = arguments.+:(targetObject) override def eval(input: InternalRow): Any = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ec4e7b2042bc..72850eaab52c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange} @@ -1432,4 +1433,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { getMessage() assert(e1.startsWith("Path does not exist")) } + + test("SPARK-14139: map on row and preserve schema nullability") { + val df1 = Seq(1, 2, 3).toDF + assert(df1.map(row => Row(row.getInt(0) + 1))(RowEncoder(df1.schema)).schema === df1.schema) + } } From 3e32b6aa4dbc0adcdd892ee838ccaed77a67dc58 Mon Sep 17 00:00:00 2001 From: Koert 
Kuipers Date: Sat, 26 Mar 2016 17:05:11 -0400 Subject: [PATCH 02/11] make scalastyle happy --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 633f23855f77..953b3e122223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -26,8 +26,8 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange} From 2600a2e28e56f580db512bc7474cdfd1fd17df58 Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Tue, 29 Mar 2016 15:09:58 -0400 Subject: [PATCH 03/11] first try at using GetExternalRowField instead of Invoke. fails unit test RowEncoderSuite:encode/decode:Product --- .../sql/catalyst/encoders/RowEncoder.scala | 14 ++--- .../sql/catalyst/expressions/objects.scala | 55 ++++++++++++++++++- 2 files changed, 58 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 1f2afb6ba86d..7096b5788091 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -120,20 +120,16 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => - val method = if (f.dataType.isInstanceOf[StructType]) { - "getStruct" - } else { - "get" - } val x = extractorsFor( - Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil, - f.nullable), - f.dataType) + GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType), f.nullable), + f.dataType + ) if (f.nullable) { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, f.dataType), - x) + x + ) } else { x } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 3ada3b0ce236..e4c530d92d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -109,9 +109,9 @@ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, - arguments: Seq[Expression] = Nil, - val nullable: Boolean = true) extends Expression with NonSQLExpression { + arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { + override def nullable: Boolean = true override def children: Seq[Expression] = arguments.+:(targetObject) override def eval(input: InternalRow): Any = @@ -680,3 +680,54 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) """ } } + +case class GetExternalRowField( + targetObject: 
Expression, + index: Int, + dataType: DataType, + nullable: Boolean) extends Expression with NonSQLExpression { + + override def children: Seq[Expression] = Seq(targetObject) + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val javaType = ctx.javaType(dataType) + val obj = targetObject.gen(ctx) + + val get = dataType match { + case IntegerType => s"""${obj.value}.getInt($index)""" + case LongType => s"""${obj.value}.getLong($index)""" + case FloatType => s"""${obj.value}.getFloat($index)""" + case ShortType => s"""${obj.value}.getShort($index)""" + case ByteType => s"""${obj.value}.getByte($index)""" + case DoubleType => s"""${obj.value}.getDouble($index)""" + case BooleanType => s"""${obj.value}.getBoolean($index)""" + case _: StructType => s"""${obj.value}.getStruct($index)""" + case _ => s"""((${javaType}) ${obj.value}.get($index))""" + } + + val code = if (nullable) { + s""" + ${obj.code} + final ${javaType} ${ev.value}; + final boolean ${ev.isNull}; + if (${obj.value}.isNullAt(${index})) { + ${ev.value} = ${ctx.defaultValue(dataType)}; + ${ev.isNull} = true; + } else { + ${ev.value} = ${get}; + ${ev.isNull} = false; + } + """ + } else { + s""" + ${obj.code} + final ${javaType} ${ev.value} = ${get}; + final boolean ${ev.isNull} = false; + """ + } + code + } +} From 4b140e5d2f3efa1d4866ec41032d5f1704b6f332 Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Sun, 3 Apr 2016 16:03:52 -0400 Subject: [PATCH 04/11] fix for extractsFor renamed to serializerFor --- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 849cd4fa1098..becfc4f91e0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -120,7 +120,7 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => - val x = extractorsFor( + val x = serializerFor( GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType), f.nullable), f.dataType ) From e9a9a30e1804785d3534bea78cf2ce588f7fc51b Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Sun, 3 Apr 2016 16:05:20 -0400 Subject: [PATCH 05/11] fix pattern match in GetExternalRowField where the data types are external --- .../org/apache/spark/sql/catalyst/expressions/objects.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index e4c530d92d3d..13b93cfe24c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -704,11 +704,11 @@ case class GetExternalRowField( case ByteType => s"""${obj.value}.getByte($index)""" case DoubleType => s"""${obj.value}.getDouble($index)""" case BooleanType => s"""${obj.value}.getBoolean($index)""" - case _: StructType => s"""${obj.value}.getStruct($index)""" + case ObjectType(x) if x == classOf[Row] => 
s"""${obj.value}.getStruct($index)""" case _ => s"""((${javaType}) ${obj.value}.get($index))""" } - val code = if (nullable) { + if (nullable) { s""" ${obj.code} final ${javaType} ${ev.value}; @@ -728,6 +728,5 @@ case class GetExternalRowField( final boolean ${ev.isNull} = false; """ } - code } } From 2976d26d5de2870541213102a5179b99f0335e94 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 14 Apr 2016 00:56:37 +0800 Subject: [PATCH 06/11] some simplification --- .../sql/catalyst/encoders/RowEncoder.scala | 6 +- .../sql/catalyst/expressions/objects.scala | 62 +++++++------------ .../catalyst/encoders/RowEncoderSuite.scala | 8 +++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 -- 4 files changed, 35 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index becfc4f91e0b..84d1897e46ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -120,7 +120,7 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => - val x = serializerFor( + val fieldValue = serializerFor( GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType), f.nullable), f.dataType ) @@ -128,10 +128,10 @@ object RowEncoder { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, f.dataType), - x + fieldValue ) } else { - x + fieldValue } } If(IsNull(inputObject), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 6525573cf1db..07022cdda0d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -709,52 +709,38 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) } } +/** + * Returns the value of field at index `index` from the external row `child`. 
+ */ case class GetExternalRowField( - targetObject: Expression, + child: Expression, index: Int, dataType: DataType, - nullable: Boolean) extends Expression with NonSQLExpression { - - override def children: Seq[Expression] = Seq(targetObject) + override val nullable: Boolean) extends UnaryExpression with NonSQLExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val javaType = ctx.javaType(dataType) - val obj = targetObject.gen(ctx) - - val get = dataType match { - case IntegerType => s"""${obj.value}.getInt($index)""" - case LongType => s"""${obj.value}.getLong($index)""" - case FloatType => s"""${obj.value}.getFloat($index)""" - case ShortType => s"""${obj.value}.getShort($index)""" - case ByteType => s"""${obj.value}.getByte($index)""" - case DoubleType => s"""${obj.value}.getDouble($index)""" - case BooleanType => s"""${obj.value}.getBoolean($index)""" - case ObjectType(x) if x == classOf[Row] => s"""${obj.value}.getStruct($index)""" - case _ => s"""((${javaType}) ${obj.value}.get($index))""" - } + nullSafeCodeGen(ctx, ev, eval => { + val getField = dataType match { + case ObjectType(x) if x == classOf[Row] => s"""$eval.getStruct($index)""" + case _ => s"""((${ctx.boxedType(dataType)}) $eval.get($index))""" + } - if (nullable) { - s""" - ${obj.code} - final ${javaType} ${ev.value}; - final boolean ${ev.isNull}; - if (${obj.value}.isNullAt(${index})) { - ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ev.isNull} = true; - } else { - ${ev.value} = ${get}; - ${ev.isNull} = false; - } - """ - } else { - s""" - ${obj.code} - final ${javaType} ${ev.value} = ${get}; - final boolean ${ev.isNull} = false; - """ - } + if (nullable) { + s""" + if ($eval.isNullAt($index)) { + ${ev.isNull} = true; + } else { + ${ev.value} = $getField; + } + """ + } else { + s""" + ${ev.value} = $getField; + """ + } + }) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index a8fa372b1ee3..98be3b053d5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -160,6 +160,14 @@ class RowEncoderSuite extends SparkFunSuite { .compareTo(convertedBack.getDecimal(3)) == 0) } + test("RowEncoder should preserve schema nullability") { + val schema = new StructType().add("int", IntegerType, nullable = false) + val encoder = RowEncoder(schema) + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == IntegerType) + assert(encoder.serializer.head.nullable == false) + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d3b435ca5ed4..e953a6e8ef0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -26,7 +26,6 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.encoders.RowEncoder import 
org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.aggregate.TungstenAggregate @@ -1430,9 +1429,4 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { getMessage() assert(e1.startsWith("Path does not exist")) } - - test("SPARK-14139: map on row and preserve schema nullability") { - val df1 = Seq(1, 2, 3).toDF - assert(df1.map(row => Row(row.getInt(0) + 1))(RowEncoder(df1.schema)).schema === df1.schema) - } } From ad8c9ef160bc1042704207efa48919c8b8cb1313 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 15 Apr 2016 02:27:50 +0800 Subject: [PATCH 07/11] add runtime null check --- .../sql/catalyst/encoders/RowEncoder.scala | 18 +++++---- .../sql/catalyst/expressions/objects.scala | 40 +++++++++---------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 84d1897e46ae..b61d5e609a24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -34,9 +34,8 @@ import org.apache.spark.unsafe.types.UTF8String object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - // We use an If expression to wrap extractorsFor result of StructType - val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue + val inputObject = BoundReference(0, ObjectType(cls), nullable = false) + val serializer = serializerFor(inputObject, schema) val deserializer = deserializerFor(schema) new ExpressionEncoder[Row]( schema, @@ -121,7 +120,7 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => val fieldValue = serializerFor( - GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType), f.nullable), + GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)), f.dataType ) if (f.nullable) { @@ -134,9 +133,14 @@ object RowEncoder { fieldValue } } - If(IsNull(inputObject), - Literal.create(null, inputType), - CreateStruct(convertedFields)) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + CreateStruct(convertedFields)) + } else { + CreateStruct(convertedFields) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 07022cdda0d3..6879275c5a99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -711,36 +711,36 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) /** * Returns the value of field at index `index` from the external row `child`. + * + * Note that the input row and the field we try to get are both guaranteed to be not null, if they + * are null, a runtime exception will be thrown. 
*/ case class GetExternalRowField( child: Expression, index: Int, - dataType: DataType, - override val nullable: Boolean) extends UnaryExpression with NonSQLExpression { + dataType: DataType) extends UnaryExpression with NonSQLExpression { + + override def nullable: Boolean = false override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - val getField = dataType match { - case ObjectType(x) if x == classOf[Row] => s"""$eval.getStruct($index)""" - case _ => s"""((${ctx.boxedType(dataType)}) $eval.get($index))""" - } + val row = child.gen(ctx) - if (nullable) { - s""" - if ($eval.isNullAt($index)) { - ${ev.isNull} = true; - } else { - ${ev.value} = $getField; - } - """ - } else { - s""" - ${ev.value} = $getField; - """ + val getField = dataType match { + case ObjectType(x) if x == classOf[Row] => s"""$row.getStruct($index)""" + case _ => s"""(${ctx.boxedType(dataType)}) $row.get($index)""" + } + + ev.isNull = "false" + + s""" + ${row.code} + if (${row.isNull} || ${row.value}.isNullAt($index)) { + throw new RuntimeException("Runtime null check failed."); } - }) + final ${ctx.javaType(dataType)} ${ev.value} = $getField; + """ } } From 4ee75e95487d6ea59948d34a059e397e577c912c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Apr 2016 11:43:44 +0800 Subject: [PATCH 08/11] add test --- .../sql/catalyst/expressions/objects.scala | 4 ++-- .../org/apache/spark/sql/DatasetSuite.scala | 18 +++++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 6879275c5a99..ae47fc048354 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -729,8 +729,8 @@ case class GetExternalRowField( val row = child.gen(ctx) val getField = dataType match { - case ObjectType(x) if x == classOf[Row] => s"""$row.getStruct($index)""" - case _ => s"""(${ctx.boxedType(dataType)}) $row.get($index)""" + case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)""" + case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)""" } ev.isNull = "false" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d074535bf626..57c6e488622d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import scala.language.postfixOps -import org.apache.spark.sql.catalyst.encoders.OuterScopes +import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -626,6 +626,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext { // Make sure the generated code for this plan can compile and execute. 
checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*) } + + test("runtime null check for RowEncoder") { + val schema = new StructType().add("i", IntegerType, nullable = false) + val df = sqlContext.range(10).map(l => { + if (l % 5 == 0) { + Row(null) + } else { + Row(l) + } + })(RowEncoder(schema)) + + val message = intercept[Exception] { + df.collect() + }.getMessage + assert(message.contains("Runtime null check failed")) + } } case class OtherTuple(_1: String, _2: Int) From 645f0a0168a1d594fd9457daeec14ca57b3442ea Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Apr 2016 08:59:47 +0800 Subject: [PATCH 09/11] rebase --- .../org/apache/spark/sql/catalyst/expressions/objects.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 7d9502b0da8a..10f42266c2e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -728,8 +728,8 @@ case class GetExternalRowField( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val row = child.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + val row = child.genCode(ctx) val getField = dataType match { case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)""" From 98e1463f8d30eaec9fee2e659a138943f61dd3f4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 5 May 2016 20:22:55 +0800 Subject: [PATCH 10/11] address comments --- .../spark/sql/catalyst/expressions/objects.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 7973ecbc15d7..dbaff1625ed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -691,6 +691,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) /** * Returns the value of field at index `index` from the external row `child`. + * This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s. * * Note that the input row and the field we try to get are both guaranteed to be not null, if they * are null, a runtime exception will be thrown. @@ -715,9 +716,15 @@ case class GetExternalRowField( val code = s""" ${row.code} - if (${row.isNull} || ${row.value}.isNullAt($index)) { - throw new RuntimeException("Runtime null check failed."); + + if (${row.isNull}) { + throw new RuntimeException("The input external row cannot be null."); + } + + if (${row.value}.isNullAt($index)) { + throw new RuntimeException("The ${index}th field of input row cannot be null."); } + final ${ctx.javaType(dataType)} ${ev.value} = $getField; """ ev.copy(code = code, isNull = "false") From 18aa1265fdf0b47b042ecbc9334d95db5c5f7229 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 5 May 2016 22:41:09 +0800 Subject: [PATCH 11/11] oops... 
--- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 56707399f6a2..3cb4e52c6d41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -672,7 +672,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val message = intercept[Exception] { df.collect() }.getMessage - assert(message.contains("Runtime null check failed")) + assert(message.contains("The 0th field of input row cannot be null")) } }
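Taken together, the series makes RowEncoder carry the schema's nullability into its serializer (via GetExternalRowField) and adds a runtime null check for fields declared non-nullable. Below is a minimal sketch of the resulting behavior, reconstructed from the tests added in these patches (DataFrameSuite, DatasetSuite, RowEncoderSuite); it assumes a Spark 2.0-era SQLContext named `sqlContext` with its implicits imported, and the variable names are illustrative rather than taken from the patches.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types.{IntegerType, StructType}

    import sqlContext.implicits._

    // With these patches, RowEncoder(schema) keeps the schema's nullability in its
    // serializer instead of forcing every field to nullable = true, so a Row-to-Row
    // map with an explicit RowEncoder preserves the input schema exactly
    // (mirrors the SPARK-14139 test added in PATCH 01).
    val df = Seq(1, 2, 3).toDF("i")
    val mapped = df.map(row => Row(row.getInt(0) + 1))(RowEncoder(df.schema))
    assert(mapped.schema === df.schema)

    // A field declared non-nullable is now checked at runtime: handing it a null
    // value makes the generated code throw, with the final wording from PATCH 10-11,
    // e.g. "The 0th field of input row cannot be null."
    val strictSchema = new StructType().add("i", IntegerType, nullable = false)
    val bad = sqlContext.range(10).map { l =>
      if (l % 5 == 0) Row(null) else Row(l.toInt)
    }(RowEncoder(strictSchema))
    // bad.collect()  // fails with the runtime null check described above

The design choice this illustrates: instead of wrapping every struct-field access in an If(isNullAt, null, ...) as the original Invoke-based extractor did, the serializer consults StructField.nullable, and non-nullable fields skip the null branch entirely, trading a silent null for an explicit runtime error when the declared schema is violated.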