From 7e6195ccf6008bce4ae19d88e612595589963b93 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 19 Feb 2018 19:31:36 +0000 Subject: [PATCH 01/14] initial commit --- .../codegen/GenerateUnsafeProjection.scala | 73 +++++++++++-------- .../GenerateUnsafeProjectionSuite.scala | 71 +++++++++++++++++- 2 files changed, 113 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 998a675eecc6..bf4f4af54ec4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -43,19 +43,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => false } - // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( ctx: CodegenContext, input: String, index: String, - fieldTypes: Seq[DataType], + fieldTypeAndNullables: Seq[(DataType, Boolean)], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode( - JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), - JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) + val fieldEvals = fieldTypeAndNullables.zipWithIndex.map { case ((dt, nullable), i) => + val isNull = if (nullable) { + JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)") + } else { + FalseLiteral + } + ExprCode(isNull, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -70,7 +72,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can calculate how many bytes are | // written later. | final int $previousCursor = $rowWriter.cursor(); - | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypeAndNullables, structRowWriter)} | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -80,7 +82,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, row: String, inputs: Seq[ExprCode], - inputTypes: Seq[DataType], + inputTypeAndNullables: Seq[(DataType, Boolean)], rowWriter: String, isTopLevel: Boolean = false): String = { val resetWriter = if (isTopLevel) { @@ -98,8 +100,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.resetRowWriter();" } - val writeFields = inputs.zip(inputTypes).zipWithIndex.map { - case ((input, dataType), index) => + val writeFields = inputs.zip(inputTypeAndNullables).zipWithIndex.map { + case ((input, (dataType, nullable)), index) => val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { @@ -110,7 +112,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == FalseLiteral) { + if (input.isNull == FalseLiteral || !nullable) { s""" |${input.code} |${writeField.trim} @@ -143,11 +145,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """.stripMargin } - // TODO: if the nullability of array element is correct, we can use it to save null check. private def writeArrayToBuffer( ctx: CodegenContext, input: String, elementType: DataType, + elementNullable: Boolean, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") @@ -170,6 +172,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val element = CodeGenerator.getValue(tmpInput, et, index) + val primitiveTypeName = if (CodeGenerator.isPrimitiveType(jt)) { + CodeGenerator.primitiveTypeName(et) + } else { + "" + } + val elementAssignment = if (elementNullable) { + s""" + |if ($tmpInput.isNullAt($index)) { + | $arrayWriter.setNull$primitiveTypeName($index); + |} else { + | ${writeElement(ctx, element, index, et, arrayWriter)} + |} + """.stripMargin + } else { + writeElement(ctx, element, index, et, arrayWriter) + } + s""" |final ArrayData $tmpInput = $input; |if ($tmpInput instanceof UnsafeArrayData) { @@ -179,23 +198,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $arrayWriter.initialize($numElements); | | for (int $index = 0; $index < $numElements; $index++) { - | if ($tmpInput.isNullAt($index)) { - | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); - | } else { - | ${writeElement(ctx, element, index, et, arrayWriter)} - | } + | $elementAssignment | } |} """.stripMargin } - // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( ctx: CodegenContext, input: String, index: String, keyType: DataType, valueType: DataType, + valueNullable: Boolean, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") @@ -219,7 +234,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can write numBytes of key array later. | final int $tmpCursor = $rowWriter.cursor(); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter)} | | // Write the numBytes of key array into the first 8 bytes. | Platform.putLong( @@ -227,7 +242,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $tmpCursor - 8, | $rowWriter.cursor() - $tmpCursor); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, valueNullable, rowWriter)} | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -240,20 +255,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro dt: DataType, writer: String): String = dt match { case t: StructType => - writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) + writeStructToBuffer(ctx, input, index, t.map(e => (e.dataType, e.nullable)), writer) - case ArrayType(et, _) => + case ArrayType(et, en) => val previousCursor = ctx.freshName("previousCursor") s""" |// Remember the current cursor so that we can calculate how many bytes are |// written later. |final int $previousCursor = $writer.cursor(); - |${writeArrayToBuffer(ctx, input, et, writer)} + |${writeArrayToBuffer(ctx, input, et, en, writer)} |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """.stripMargin - case MapType(kt, vt, _) => - writeMapToBuffer(ctx, input, index, kt, vt, writer) + case MapType(kt, vt, vn) => + writeMapToBuffer(ctx, input, index, kt, vt, vn, writer) case DecimalType.Fixed(precision, scale) => s"$writer.write($index, $input, $precision, $scale);" @@ -268,10 +283,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) - val exprTypes = expressions.map(_.dataType) + val exprTypeAndNullables = expressions.map(e => (e.dataType, e.nullable)) - val numVarLenFields = exprTypes.count { - case dt if UnsafeRow.isFixedLength(dt) => false + val numVarLenFields = exprTypeAndNullables.count { + case (dt, _) if UnsafeRow.isFixedLength(dt) => false // TODO: consider large decimal and interval type case _ => true } @@ -284,7 +299,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val evalSubexpr = ctx.subexprFunctions.mkString("\n") val writeExpressions = writeExpressionsToBuffer( - ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) + ctx, ctx.INPUT_ROW, exprEvals, exprTypeAndNullables, rowWriter, isTopLevel = true) val code = code""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala index e9d21f8a8ebc..01aa3579aea9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BoundReference -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenerateUnsafeProjectionSuite extends SparkFunSuite { @@ -33,6 +33,41 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite { assert(!result.isNullAt(0)) assert(result.getStruct(0, 1).isNullAt(0)) } + + test("Test unsafe projection for array/map/struct") { + val dataType1 = ArrayType(StringType, false) + val exprs1 = BoundReference(0, dataType1, nullable = false) :: Nil + val projection1 = GenerateUnsafeProjection.generate(exprs1) + val result1 = projection1.apply(AlwaysNonNull) + assert(!result1.isNullAt(0)) + assert(!result1.getArray(0).isNullAt(0)) + assert(!result1.getArray(0).isNullAt(1)) + assert(!result1.getArray(0).isNullAt(2)) + + val dataType2 = MapType(StringType, StringType, false) + val exprs2 = BoundReference(0, dataType2, nullable = false) :: Nil + val projection2 = GenerateUnsafeProjection.generate(exprs2) + val result2 = projection2.apply(AlwaysNonNull) + assert(!result2.isNullAt(0)) + assert(!result2.getMap(0).keyArray.isNullAt(0)) + assert(!result2.getMap(0).keyArray.isNullAt(1)) + assert(!result2.getMap(0).keyArray.isNullAt(2)) + assert(!result2.getMap(0).valueArray.isNullAt(0)) + assert(!result2.getMap(0).valueArray.isNullAt(1)) + assert(!result2.getMap(0).valueArray.isNullAt(2)) + + val dataType3 = (new StructType) + .add("a", StringType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val exprs3 = BoundReference(0, dataType3, nullable = false) :: Nil + val projection3 = GenerateUnsafeProjection.generate(exprs3) + val result3 = projection3.apply(InternalRow(AlwaysNonNull)) + assert(!result3.isNullAt(0)) + assert(!result3.getStruct(0, 1).isNullAt(0)) + assert(!result3.getStruct(0, 2).isNullAt(0)) + assert(!result3.getStruct(0, 3).isNullAt(0)) + } } object AlwaysNull extends InternalRow { @@ -59,3 +94,35 @@ object AlwaysNull extends InternalRow { override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported private def notSupported: Nothing = throw new UnsupportedOperationException } + +object AlwaysNonNull extends InternalRow { + private def stringToUTF8Array(stringArray: Array[String]): ArrayData = { + val utf8Array = stringArray.map(s => UTF8String.fromString(s)).toArray + ArrayData.toArrayData(utf8Array) + } + override def numFields: Int = 1 + override def setNullAt(i: Int): Unit = {} + override def copy(): InternalRow = this + override def anyNull: Boolean = notSupported + override def isNullAt(ordinal: Int): Boolean = notSupported + override def update(i: Int, value: Any): Unit = notSupported + override def getBoolean(ordinal: Int): Boolean = notSupported + override def getByte(ordinal: Int): Byte = notSupported + override def getShort(ordinal: Int): Short = notSupported + override def getInt(ordinal: Int): Int = notSupported + override def getLong(ordinal: Int): Long = notSupported + override def getFloat(ordinal: Int): Float = notSupported + override def getDouble(ordinal: Int): Double = notSupported + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported + override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test") + override def getBinary(ordinal: Int): Array[Byte] = notSupported + override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported + override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3")) + val keyArray = stringToUTF8Array(Array("1", "2", "3")) + val valueArray = stringToUTF8Array(Array("a", "b", "c")) + override def getMap(ordinal: Int): MapData = new ArrayBasedMapData(keyArray, valueArray) + override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported + private def notSupported: Nothing = throw new UnsupportedOperationException + +} From a04a5d017ee4eabd45c2c21c4bb1cc23d1f1a546 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 8 Aug 2018 19:04:14 +0100 Subject: [PATCH 02/14] address review comments rebase with master --- .../codegen/GenerateUnsafeProjection.scala | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index bf4f4af54ec4..764a6d846583 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -72,7 +72,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can calculate how many bytes are | // written later. | final int $previousCursor = $rowWriter.cursor(); - | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypeAndNullables, structRowWriter)} + | ${writeExpressionsToBuffer( + ctx, tmpInput, fieldEvals, fieldTypeAndNullables.map(_._1), structRowWriter)} | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -82,7 +83,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, row: String, inputs: Seq[ExprCode], - inputTypeAndNullables: Seq[(DataType, Boolean)], + inputType: Seq[DataType], rowWriter: String, isTopLevel: Boolean = false): String = { val resetWriter = if (isTopLevel) { @@ -100,8 +101,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.resetRowWriter();" } - val writeFields = inputs.zip(inputTypeAndNullables).zipWithIndex.map { - case ((input, (dataType, nullable)), index) => + val writeFields = inputs.zip(inputType).zipWithIndex.map { + case ((input, dataType), index) => val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { @@ -112,7 +113,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == FalseLiteral || !nullable) { + if (input.isNull == FalseLiteral) { s""" |${input.code} |${writeField.trim} @@ -234,7 +235,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can write numBytes of key array later. | final int $tmpCursor = $rowWriter.cursor(); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter)} + | ${writeArrayToBuffer( + ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter)} | | // Write the numBytes of key array into the first 8 bytes. | Platform.putLong( @@ -242,7 +244,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $tmpCursor - 8, | $rowWriter.cursor() - $tmpCursor); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, valueNullable, rowWriter)} + | ${writeArrayToBuffer( + ctx, s"$tmpInput.valueArray()", valueType, valueNullable, rowWriter)} | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -283,12 +286,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) - val exprTypeAndNullables = expressions.map(e => (e.dataType, e.nullable)) + val exprType = expressions.map(_.dataType) - val numVarLenFields = exprTypeAndNullables.count { - case (dt, _) if UnsafeRow.isFixedLength(dt) => false + val numVarLenFields = exprType.count { + case dt => !UnsafeRow.isFixedLength(dt) // TODO: consider large decimal and interval type - case _ => true } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -299,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val evalSubexpr = ctx.subexprFunctions.mkString("\n") val writeExpressions = writeExpressionsToBuffer( - ctx, ctx.INPUT_ROW, exprEvals, exprTypeAndNullables, rowWriter, isTopLevel = true) + ctx, ctx.INPUT_ROW, exprEvals, exprType, rowWriter, isTopLevel = true) val code = code""" From 172ee593b71e7e54cd6f707a0169c0797eb30fba Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 9 Aug 2018 02:22:42 +0100 Subject: [PATCH 03/14] address review comments --- .../codegen/GenerateUnsafeProjection.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 764a6d846583..117d3ca0649c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -64,6 +64,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") val previousCursor = ctx.freshName("previousCursor") + val structExpressions = writeExpressionsToBuffer( + ctx, tmpInput, fieldEvals, fieldTypeAndNullables.map(_._1), structRowWriter) s""" |final InternalRow $tmpInput = $input; |if ($tmpInput instanceof UnsafeRow) { @@ -72,8 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can calculate how many bytes are | // written later. | final int $previousCursor = $rowWriter.cursor(); - | ${writeExpressionsToBuffer( - ctx, tmpInput, fieldEvals, fieldTypeAndNullables.map(_._1), structRowWriter)} + | $structExpressions | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -181,7 +182,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val elementAssignment = if (elementNullable) { s""" |if ($tmpInput.isNullAt($index)) { - | $arrayWriter.setNull$primitiveTypeName($index); + | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); |} else { | ${writeElement(ctx, element, index, et, arrayWriter)} |} @@ -219,6 +220,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val previousCursor = ctx.freshName("previousCursor") // Writes out unsafe map according to the format described in `UnsafeMapData`. + val keyArray = writeArrayToBuffer( + ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter) + val valueArray = writeArrayToBuffer( + ctx, s"$tmpInput.valueArray()", valueType, valueNullable, rowWriter) + s""" |final MapData $tmpInput = $input; |if ($tmpInput instanceof UnsafeMapData) { @@ -235,8 +241,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can write numBytes of key array later. | final int $tmpCursor = $rowWriter.cursor(); | - | ${writeArrayToBuffer( - ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter)} + | $keyArray | | // Write the numBytes of key array into the first 8 bytes. | Platform.putLong( @@ -244,8 +249,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $tmpCursor - 8, | $rowWriter.cursor() - $tmpCursor); | - | ${writeArrayToBuffer( - ctx, s"$tmpInput.valueArray()", valueType, valueNullable, rowWriter)} + | $valueArray | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin From 427b5298c3a26e313bf93b5a89e501e23404be6c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 10 Aug 2018 03:43:47 +0100 Subject: [PATCH 04/14] fix test failure of JsonExpressionsSuite --- .../expressions/codegen/GenerateUnsafeProjection.scala | 6 +++--- .../sql/catalyst/expressions/JsonExpressionsSuite.scala | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 117d3ca0649c..137dc34cdded 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -290,9 +290,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) - val exprType = expressions.map(_.dataType) + val exprTypes = expressions.map(_.dataType) - val numVarLenFields = exprType.count { + val numVarLenFields = exprTypes.count { case dt => !UnsafeRow.isFixedLength(dt) // TODO: consider large decimal and interval type } @@ -305,7 +305,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val evalSubexpr = ctx.subexprFunctions.mkString("\n") val writeExpressions = writeExpressionsToBuffer( - ctx, ctx.INPUT_ROW, exprEvals, exprType, rowWriter, isTopLevel = true) + ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = code""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 04f1c8ce0b83..0e9c8abec33e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -694,7 +694,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with |""".stripMargin val jsonSchema = new StructType() .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) + .add("b", StringType, nullable = !forceJsonNullableSchema) .add("c", StringType, nullable = false) val output = InternalRow(1L, null, UTF8String.fromString("foo")) val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) From c6146e14899e31541f3d72faafbdda33bf6ee178 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 15 Aug 2018 17:08:58 +0100 Subject: [PATCH 05/14] address review comment --- .../codegen/GenerateUnsafeProjection.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 137dc34cdded..f4362a6807eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.types._ */ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { + case class Schema(dataType: DataType, nullable: Boolean) + /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match { case NullType => true @@ -47,17 +49,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, index: String, - fieldTypeAndNullables: Seq[(DataType, Boolean)], + fieldTypeAndNullables: Seq[Schema], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val fieldEvals = fieldTypeAndNullables.zipWithIndex.map { case ((dt, nullable), i) => - val isNull = if (nullable) { + val fieldEvals = fieldTypeAndNullables.zipWithIndex.map { case (dtNullable, i) => + val isNull = if (dtNullable.nullable) { JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)") } else { FalseLiteral } - ExprCode(isNull, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) + ExprCode(isNull, JavaCode.expression( + CodeGenerator.getValue(tmpInput, dtNullable.dataType, i.toString), dtNullable.dataType)) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -65,7 +68,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") val previousCursor = ctx.freshName("previousCursor") val structExpressions = writeExpressionsToBuffer( - ctx, tmpInput, fieldEvals, fieldTypeAndNullables.map(_._1), structRowWriter) + ctx, tmpInput, fieldEvals, fieldTypeAndNullables.map(_.dataType), structRowWriter) s""" |final InternalRow $tmpInput = $input; |if ($tmpInput instanceof UnsafeRow) { @@ -174,11 +177,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val element = CodeGenerator.getValue(tmpInput, et, index) - val primitiveTypeName = if (CodeGenerator.isPrimitiveType(jt)) { - CodeGenerator.primitiveTypeName(et) - } else { - "" - } val elementAssignment = if (elementNullable) { s""" |if ($tmpInput.isNullAt($index)) { @@ -262,7 +260,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro dt: DataType, writer: String): String = dt match { case t: StructType => - writeStructToBuffer(ctx, input, index, t.map(e => (e.dataType, e.nullable)), writer) + writeStructToBuffer( + ctx, input, index, t.map(e => Schema(e.dataType, e.nullable)), writer) case ArrayType(et, en) => val previousCursor = ctx.freshName("previousCursor") From 5f14c157f32f687bb14dcfe779897a6ad8cfe739 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 17 Aug 2018 15:51:56 +0100 Subject: [PATCH 06/14] updates --- .../codegen/GenerateUnsafeProjection.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index f4362a6807eb..e12ad559f371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -87,7 +87,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, row: String, inputs: Seq[ExprCode], - inputType: Seq[DataType], + inputTypes: Seq[DataType], rowWriter: String, isTopLevel: Boolean = false): String = { val resetWriter = if (isTopLevel) { @@ -105,7 +105,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.resetRowWriter();" } - val writeFields = inputs.zip(inputType).zipWithIndex.map { + val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => val dt = UserDefinedType.sqlType(dataType) @@ -154,7 +154,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, - elementNullable: Boolean, + containsNull: Boolean, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") @@ -177,7 +177,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val element = CodeGenerator.getValue(tmpInput, et, index) - val elementAssignment = if (elementNullable) { + val elementAssignment = if (containsNull) { s""" |if ($tmpInput.isNullAt($index)) { | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); @@ -210,7 +210,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro index: String, keyType: DataType, valueType: DataType, - valueNullable: Boolean, + valueContainsNull: Boolean, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") @@ -221,7 +221,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val keyArray = writeArrayToBuffer( ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter) val valueArray = writeArrayToBuffer( - ctx, s"$tmpInput.valueArray()", valueType, valueNullable, rowWriter) + ctx, s"$tmpInput.valueArray()", valueType, valueContainsNull, rowWriter) s""" |final MapData $tmpInput = $input; From 6957dc83cb83cdf4e7b4c37828490058288d3f51 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 17 Aug 2018 15:53:52 +0100 Subject: [PATCH 07/14] improve test coverage regarding nullable in Unsafe structures --- .../expressions/ExpressionEvalHelper.scala | 12 ++++++---- .../ExpressionEvalHelperSuite.scala | 22 +++++++++++++++++-- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 6684e5ce18d4..b1b891c332c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -59,12 +59,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa // Make it as method to obtain fresh expression everytime. def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) - checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) - checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) + // checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) + // checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) + print(s"checkEvaluation: ${expr.dataType}\n") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + print(s"HERE\n") checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow) } - checkEvaluationWithOptimization(expr, catalystValue, inputRow) + // checkEvaluationWithOptimization(expr, catalystValue, inputRow) } /** @@ -92,6 +94,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa var i = 0 while (isSame && i < result.numElements) { isSame = checkResult(result.get(i, et), expected.get(i, et), et) + print(s"[$i]: ${result.isNullAt(i)}, ${result.get(i, et)}, ${expected.isNullAt(i)}, ${expected.get(i, et)}, $isSame\n") i += 1 } isSame @@ -223,8 +226,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa } } else { val lit = InternalRow(expected, expected) + val dtAsNullable = expression.dataType.asNullable val expectedRow = - UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + UnsafeProjection.create(Array(dtAsNullable, dtAsNullable)).apply(lit) if (unsafeRow != expectedRow) { fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 7c7c4cccee25..5104eee40bc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType} /** * A test suite for testing [[ExpressionEvalHelper]]. @@ -35,6 +35,24 @@ class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper val e = intercept[RuntimeException] { checkEvaluation(BadCodegenExpression(), 10) } assert(e.getMessage.contains("some_variable")) } + + test("SPARK-23466: checkEvaluationWithUnsafeProjection should fail if null is compared with " + + "primitive default value") { + val expected = Array(null, -1, 0, 1) + val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) + + val expression1 = CreateArray( + Seq(Literal(null, IntegerType), Literal(-1), Literal(0), Literal(1))) + assert(expression1.dataType.containsNull) + checkEvaluationWithUnsafeProjection(expression1, catalystValue) + + val expression2 = CreateArray(Seq(Literal(0, IntegerType), Literal(-1), Literal(0), Literal(1))) + assert(!expression2.dataType.containsNull) + val e = intercept[RuntimeException] { + checkEvaluationWithUnsafeProjection(expression2, catalystValue) + } + assert(e.getMessage.contains("Incorrect evaluation in unsafe mode")) + } } /** From 14590b712b3d240d30ba448c4ed38a012e6779a0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 17 Aug 2018 16:27:41 +0100 Subject: [PATCH 08/14] remove debug code --- .../sql/catalyst/expressions/ExpressionEvalHelper.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b1b891c332c8..9559ea14ad8f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -59,14 +59,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa // Make it as method to obtain fresh expression everytime. def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) - // checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) - // checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) - print(s"checkEvaluation: ${expr.dataType}\n") + checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) + checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expr.dataType)) { - print(s"HERE\n") checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow) } - // checkEvaluationWithOptimization(expr, catalystValue, inputRow) + checkEvaluationWithOptimization(expr, catalystValue, inputRow) } /** @@ -94,7 +92,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa var i = 0 while (isSame && i < result.numElements) { isSame = checkResult(result.get(i, et), expected.get(i, et), et) - print(s"[$i]: ${result.isNullAt(i)}, ${result.get(i, et)}, ${expected.isNullAt(i)}, ${expected.get(i, et)}, $isSame\n") i += 1 } isSame From e89954ae131b9c3ffb1163fa19a6549d5a1c01c4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 20 Aug 2018 20:12:29 +0100 Subject: [PATCH 09/14] address review comments --- .../expressions/codegen/GenerateUnsafeProjection.scala | 7 +++---- .../catalyst/expressions/ExpressionEvalHelperSuite.scala | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index e12ad559f371..087d3cc15221 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -53,14 +53,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val fieldEvals = fieldTypeAndNullables.zipWithIndex.map { case (dtNullable, i) => - val isNull = if (dtNullable.nullable) { + val fieldEvals = fieldTypeAndNullables.zipWithIndex.map { case (Schema(dt, nullable), i) => + val isNull = if (nullable) { JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)") } else { FalseLiteral } - ExprCode(isNull, JavaCode.expression( - CodeGenerator.getValue(tmpInput, dtNullable.dataType, i.toString), dtNullable.dataType)) + ExprCode(isNull, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) } val rowWriterClass = classOf[UnsafeRowWriter].getName diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 5104eee40bc5..8b3464999110 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType} /** * A test suite for testing [[ExpressionEvalHelper]]. From b346434ca98d08367e3e6a913f91994a9b56b25b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 21 Aug 2018 19:00:02 +0100 Subject: [PATCH 10/14] address review comment --- .../codegen/GenerateUnsafeProjection.scala | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 087d3cc15221..7293ccf42693 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -49,11 +49,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, index: String, - fieldTypeAndNullables: Seq[Schema], + schemas: Seq[Schema], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val fieldEvals = fieldTypeAndNullables.zipWithIndex.map { case (Schema(dt, nullable), i) => + val fieldEvals = schemas.zipWithIndex.map { case (Schema(dt, nullable), i) => val isNull = if (nullable) { JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)") } else { @@ -67,7 +67,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") val previousCursor = ctx.freshName("previousCursor") val structExpressions = writeExpressionsToBuffer( - ctx, tmpInput, fieldEvals, fieldTypeAndNullables.map(_.dataType), structRowWriter) + ctx, tmpInput, fieldEvals, schemas, structRowWriter) s""" |final InternalRow $tmpInput = $input; |if ($tmpInput instanceof UnsafeRow) { @@ -86,7 +86,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, row: String, inputs: Seq[ExprCode], - inputTypes: Seq[DataType], + schemas: Seq[Schema], rowWriter: String, isTopLevel: Boolean = false): String = { val resetWriter = if (isTopLevel) { @@ -104,8 +104,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.resetRowWriter();" } - val writeFields = inputs.zip(inputTypes).zipWithIndex.map { - case ((input, dataType), index) => + val writeFields = inputs.zip(schemas).zipWithIndex.map { + case ((input, Schema(dataType, nullable)), index) => val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { @@ -116,7 +116,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == FalseLiteral) { + if (input.isNull == FalseLiteral || !nullable) { s""" |${input.code} |${writeField.trim} @@ -288,10 +288,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) - val exprTypes = expressions.map(_.dataType) + val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable)) - val numVarLenFields = exprTypes.count { - case dt => !UnsafeRow.isFixedLength(dt) + val numVarLenFields = exprSchemas.count { + case Schema(dt, _) => !UnsafeRow.isFixedLength(dt) // TODO: consider large decimal and interval type } @@ -303,7 +303,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val evalSubexpr = ctx.subexprFunctions.mkString("\n") val writeExpressions = writeExpressionsToBuffer( - ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) + ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true) val code = code""" From 980ca2ebf937fd77a89de1498b896145eb3dc199 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 22 Aug 2018 16:37:18 +0100 Subject: [PATCH 11/14] address review comment --- .../catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7293ccf42693..c8aedf582edb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -116,7 +116,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == FalseLiteral || !nullable) { + if (!nullable) { s""" |${input.code} |${writeField.trim} From bbc9340e2bf047d630b2478ee8d15e5b2c084c58 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 24 Aug 2018 11:28:10 +0100 Subject: [PATCH 12/14] address review comment --- .../expressions/codegen/GenerateUnsafeProjection.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index c8aedf582edb..0ecd0de8d820 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -66,8 +66,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") val previousCursor = ctx.freshName("previousCursor") - val structExpressions = writeExpressionsToBuffer( - ctx, tmpInput, fieldEvals, schemas, structRowWriter) s""" |final InternalRow $tmpInput = $input; |if ($tmpInput instanceof UnsafeRow) { @@ -76,7 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can calculate how many bytes are | // written later. | final int $previousCursor = $rowWriter.cursor(); - | $structExpressions + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas, structRowWriter)} | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin From 37dc4d8da0d9654c36494b516cfdfb869e66afc2 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 30 Aug 2018 17:14:53 +0100 Subject: [PATCH 13/14] address review comments --- .../ExpressionEvalHelperSuite.scala | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 8b3464999110..7c7c4cccee25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -35,24 +35,6 @@ class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper val e = intercept[RuntimeException] { checkEvaluation(BadCodegenExpression(), 10) } assert(e.getMessage.contains("some_variable")) } - - test("SPARK-23466: checkEvaluationWithUnsafeProjection should fail if null is compared with " + - "primitive default value") { - val expected = Array(null, -1, 0, 1) - val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) - - val expression1 = CreateArray( - Seq(Literal(null, IntegerType), Literal(-1), Literal(0), Literal(1))) - assert(expression1.dataType.containsNull) - checkEvaluationWithUnsafeProjection(expression1, catalystValue) - - val expression2 = CreateArray(Seq(Literal(0, IntegerType), Literal(-1), Literal(0), Literal(1))) - assert(!expression2.dataType.containsNull) - val e = intercept[RuntimeException] { - checkEvaluationWithUnsafeProjection(expression2, catalystValue) - } - assert(e.getMessage.contains("Incorrect evaluation in unsafe mode")) - } } /** From 88c74c61d5ab64eb860b91261d013702983ac49c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 31 Aug 2018 16:03:32 +0100 Subject: [PATCH 14/14] address review comment --- .../spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 9559ea14ad8f..6684e5ce18d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -223,9 +223,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa } } else { val lit = InternalRow(expected, expected) - val dtAsNullable = expression.dataType.asNullable val expectedRow = - UnsafeProjection.create(Array(dtAsNullable, dtAsNullable)).apply(lit) + UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) if (unsafeRow != expectedRow) { fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input")