diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 89fe7c48c000..b61583d0dafb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -484,7 +484,7 @@ case class JsonTuple(children: Seq[Expression])
  * Converts an json input string to a [[StructType]] with the specified schema.
  */
 case class JsonToStruct(schema: StructType, options: Map[String, String], child: Expression)
-  extends Expression with CodegenFallback with ExpectsInputTypes {
+  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
   override def nullable: Boolean = true
 
   @transient
@@ -495,11 +495,8 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:
       new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE)))
 
   override def dataType: DataType = schema
-  override def children: Seq[Expression] = child :: Nil
 
-  override def eval(input: InternalRow): Any = {
-    val json = child.eval(input)
-    if (json == null) return null
+  override def nullSafeEval(json: Any): Any = {
     try parser.parse(json.toString).head catch {
       case _: SparkSQLJsonProcessingException => null
     }
@@ -512,7 +509,7 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:
  * Converts a [[StructType]] to a json output string.
  */
 case class StructToJson(options: Map[String, String], child: Expression)
-  extends Expression with CodegenFallback with ExpectsInputTypes {
+  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
   override def nullable: Boolean = true
 
   @transient
@@ -523,7 +520,6 @@ case class StructToJson(options: Map[String, String], child: Expression)
     new JacksonGenerator(child.dataType.asInstanceOf[StructType], writer)
 
   override def dataType: DataType = StringType
-  override def children: Seq[Expression] = child :: Nil
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (StructType.acceptsType(child.dataType)) {
@@ -540,8 +536,8 @@ case class StructToJson(options: Map[String, String], child: Expression)
     }
   }
 
-  override def eval(input: InternalRow): Any = {
-    gen.write(child.eval(input).asInstanceOf[InternalRow])
+  override def nullSafeEval(row: Any): Any = {
+    gen.write(row.asInstanceOf[InternalRow])
     gen.flush()
     val json = writer.toString
     writer.reset()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 3bfa0bfda620..3b0e90824b76 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.ParseModes
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
 class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -347,7 +347,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("from_json null input column") {
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStruct(schema, Map.empty, Literal(null)),
+      JsonToStruct(schema, Map.empty, Literal.create(null, StringType)),
       null
     )
   }
@@ -360,4 +360,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       """{"a":1}"""
     )
   }
+
+  test("to_json null input column") {
+    val schema = StructType(StructField("a", IntegerType) :: Nil)
+    val struct = Literal.create(null, schema)
+    checkEvaluation(
+      StructToJson(Map.empty, struct),
+      null
+    )
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 59ae889cf3b9..7d63d31d9b97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -141,4 +141,18 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(e.getMessage.contains(
       "Unable to convert column a of type calendarinterval to JSON."))
   }
+
+  test("roundtrip in to_json and from_json") {
+    val dfOne = Seq(Some(Tuple1(Tuple1(1))), None).toDF("struct")
+    val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType]
+    val readBackOne = dfOne.select(to_json($"struct").as("json"))
+      .select(from_json($"json", schemaOne).as("struct"))
+    checkAnswer(dfOne, readBackOne)
+
+    val dfTwo = Seq(Some("""{"a":1}"""), None).toDF("json")
+    val schemaTwo = new StructType().add("a", IntegerType)
+    val readBackTwo = dfTwo.select(from_json($"json", schemaTwo).as("struct"))
+      .select(to_json($"struct").as("json"))
+    checkAnswer(dfTwo, readBackTwo)
+  }
 }
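
Note on why the explicit guard can be dropped: Catalyst's UnaryExpression implements eval() to return null whenever the child expression evaluates to null, delegating to nullSafeEval() only for non-null values. That is what makes the hand-written "if (json == null) return null" in JsonToStruct (and the unguarded cast in StructToJson) redundant once both classes extend UnaryExpression. Below is a minimal, self-contained sketch of that pattern; it is an assumed simplification for illustration, not Spark's actual class hierarchy:

    // Simplified sketch of the null-safe unary evaluation pattern (hypothetical
    // classes, not Spark's real UnaryExpression): eval() intercepts a null child
    // value, so subclasses implement only the non-null case in nullSafeEval().
    object NullSafeEvalSketch {

      abstract class UnaryExpr {
        // Stands in for child.eval(input) in Catalyst.
        def evalChild(): Any

        // Short-circuits on null before delegating to the null-safe hook.
        final def eval(): Any = {
          val value = evalChild()
          if (value == null) null else nullSafeEval(value)
        }

        protected def nullSafeEval(value: Any): Any
      }

      // Toy analogue of JsonToStruct: the body never has to test for null.
      class InputLength(input: String) extends UnaryExpr {
        override def evalChild(): Any = input
        override protected def nullSafeEval(value: Any): Any = value.toString.length
      }

      def main(args: Array[String]): Unit = {
        println(new InputLength("""{"a":1}""").eval()) // 7
        println(new InputLength(null).eval())          // null
      }
    }

The new "from_json null input column" and "to_json null input column" tests above exercise exactly this path: a typed null literal flows through eval() and comes back as null without reaching the parsing or generating code.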