diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 568927a35e87..d2e461b2cb69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -406,19 +406,21 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit if (row.numFields > 0) { val st = fields.map(_.dataType) val toUTF8StringFuncs = st.map(castToString) - if (row.isNullAt(0)) { + if (fields(0).nullable && row.isNullAt(0)) { if (!legacyCastToStr) builder.append("null") } else { - builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String]) + val accessor = InternalRow.getAccessor(fields(0).dataType, fields(0).nullable) + builder.append(toUTF8StringFuncs(0)(accessor(row, 0)).asInstanceOf[UTF8String]) } var i = 1 while (i < row.numFields) { builder.append(",") - if (row.isNullAt(i)) { + if (fields(i).nullable && row.isNullAt(i)) { if (!legacyCastToStr) builder.append(" null") } else { builder.append(" ") - builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String]) + val accessor = InternalRow.getAccessor(fields(i).dataType, fields(i).nullable) + builder.append(toUTF8StringFuncs(i)(accessor(row, i)).asInstanceOf[UTF8String]) } i += 1 } @@ -868,8 +870,13 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit val newRow = new GenericInternalRow(from.fields.length) var i = 0 while (i < row.numFields) { - newRow.update(i, - if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from.apply(i).dataType))) + val value = if (from.fields(i).nullable && row.isNullAt(i)) { + null + } else { + val accessor = InternalRow.getAccessor(from.fields(i).dataType, from.fields(i).nullable) + castFuncs(i)(accessor(row, i)) + } + newRow.update(i, value) i += 1 } newRow @@ -1098,29 +1105,37 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } private def writeStructToStringBuilder( - st: Seq[DataType], + st: Seq[StructField], row: ExprValue, buffer: ExprValue, ctx: CodegenContext): Block = { - val structToStringCode = st.zipWithIndex.map { case (ft, i) => - val fieldToStringCode = castToStringCode(ft, ctx) - val field = ctx.freshVariable("field", ft) - val fieldStr = ctx.freshVariable("fieldStr", StringType) - val javaType = JavaCode.javaType(ft) - code""" - |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock} - |if ($row.isNullAt($i)) { - | ${appendIfNotLegacyCastToStr(buffer, if (i == 0) "null" else " null")} - |} else { - | ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock} - | - | // Append $i field into the string buffer - | $javaType $field = ${CodeGenerator.getValue(row, ft, s"$i")}; - | UTF8String $fieldStr = null; - | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} - | $buffer.append($fieldStr); - |} - """.stripMargin + val structToStringCode = st.zipWithIndex.map { + case (StructField(_, dataType, nullable, _), i) => + val fieldToStringCode = castToStringCode(dataType, ctx) + val field = ctx.freshVariable("field", dataType) + val fieldStr = ctx.freshVariable("fieldStr", StringType) + val javaType = JavaCode.javaType(dataType) + + val isNull = if (nullable) { + code"$row.isNullAt($i)" + } else { + code"false" + } + + code""" + |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock} + |if ($isNull) { + | ${appendIfNotLegacyCastToStr(buffer, if (i == 0) "null" else " null")} + |} else { + | ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock} + | + | // Append $i field into the string buffer + | $javaType $field = ${CodeGenerator.getValue(row, dataType, s"$i")}; + | UTF8String $fieldStr = null; + | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} + | $buffer.append($fieldStr); + |} + """.stripMargin } val writeStructCode = ctx.splitExpressions( @@ -1184,7 +1199,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit val row = ctx.freshVariable("row", classOf[InternalRow]) val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) - val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) + val writeStructCode = writeStructToStringBuilder(fields, row, buffer, ctx) code""" |InternalRow $row = $c; |$bufferClass $buffer = new $bufferClass(); @@ -1890,8 +1905,15 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit val toFieldNull = ctx.freshVariable("tfn", BooleanType) val fromType = JavaCode.javaType(from.fields(i).dataType) val setColumn = CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim) + + val isNull = if (from.fields(i).nullable) { + code"boolean $fromFieldNull = $tmpInput.isNullAt($i);" + } else { + code"boolean $fromFieldNull = false;" + } + code""" - boolean $fromFieldNull = $tmpInput.isNullAt($i); + $isNull if ($fromFieldNull) { $tmpResult.setNullAt($i); } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 48835bf9db2c..f10a82c6ca44 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -1271,4 +1271,49 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } } } + + test("SPARK-35912: Cast struct contains the null value to string") { + Seq(true, false).foreach { nullable => + val lit = Literal.create(InternalRow(InternalRow(1, null)), + StructType(Seq(StructField("c1", + StructType(Seq( + StructField("c2", IntegerType, true), + StructField("c3", IntegerType, nullable) + )) + ))) + ) + val ret = cast(lit, StringType) + assert(ret.resolved) + val expected = if (nullable) { + "{{1, null}}" + } else { + "{{1, 0}}" + } + checkEvaluation(ret, expected) + } + } + + test("SPARK-35912: Cast struct contains the null value to struct") { + Seq(true, false).foreach { nullable => + val lit = Literal.create(InternalRow(1, null), + StructType(Seq( + StructField("c1", IntegerType, true), + StructField("c2", IntegerType, nullable) + )) + ) + val toType = StructType(Seq( + StructField("c1", StringType, true), + StructField("c2", StringType, true) + )) + + val expected = if (nullable) { + InternalRow(UTF8String.fromString("1"), null) + } else { + InternalRow(UTF8String.fromString("1"), UTF8String.fromString("0")) + } + val ret = cast(lit, toType) + assert(ret.resolved) + checkEvaluation(ret, expected) + } + } }