diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 3f3d6b2b63a0..56bd3d7026d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -27,19 +27,22 @@ import org.apache.spark.unsafe.types.UTF8String
 
 object ExprUtils {
 
-  def evalSchemaExpr(exp: Expression): StructType = {
-    // Use `DataType.fromDDL` since the type string can be struct<...>.
-    val dataType = exp match {
-      case Literal(s, StringType) =>
-        DataType.fromDDL(s.toString)
-      case e @ SchemaOfCsv(_: Literal, _) =>
-        val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
-        DataType.fromDDL(ddlSchema.toString)
-      case e => throw new AnalysisException(
+  def evalTypeExpr(exp: Expression): DataType = {
+    if (exp.foldable) {
+      exp.eval() match {
+        case s: UTF8String if s != null => DataType.fromDDL(s.toString)
+        case _ => throw new AnalysisException(
+          s"The expression '${exp.sql}' is not a valid schema string.")
+      }
+    } else {
+      throw new AnalysisException(
         "Schema should be specified in DDL format as a string literal or output of " +
-          s"the schema_of_csv function instead of ${e.sql}")
+          s"the schema_of_json/schema_of_csv functions instead of ${exp.sql}")
     }
+  }
 
+  def evalSchemaExpr(exp: Expression): StructType = {
+    val dataType = evalTypeExpr(exp)
     if (!dataType.isInstanceOf[StructType]) {
       throw new AnalysisException(
         s"Schema should be struct type but got ${dataType.sql}.")
@@ -47,16 +50,6 @@ object ExprUtils {
     dataType.asInstanceOf[StructType]
   }
 
-  def evalTypeExpr(exp: Expression): DataType = exp match {
-    case Literal(s, StringType) => DataType.fromDDL(s.toString)
-    case e @ SchemaOfJson(_: Literal, _) =>
-      val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String]
-      DataType.fromDDL(ddlSchema.toString)
-    case e => throw new AnalysisException(
-      "Schema should be specified in DDL format as a string literal or output of " +
-      s"the schema_of_json function instead of ${e.sql}")
-  }
-
   def convertToMapData(exp: Expression): Map[String, String] = exp match {
     case m: CreateMap
         if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index 54af314fe417..5140db90c595 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -165,10 +165,14 @@ case class SchemaOfCsv(
   @transient
   private lazy val csv = child.eval().asInstanceOf[UTF8String]
 
-  override def checkInputDataTypes(): TypeCheckResult = child match {
-    case Literal(s, StringType) if s != null => super.checkInputDataTypes()
-    case _ => TypeCheckResult.TypeCheckFailure(
-      s"The input csv should be a string literal and not null; however, got ${child.sql}.")
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (child.foldable && csv != null) {
+      super.checkInputDataTypes()
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        "The input csv should be a foldable string expression and not null; " +
+          s"however, got ${child.sql}.")
+    }
   }
 
   override def eval(v: InternalRow): Any = {
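Note: `checkInputDataTypes` above now keys off `Expression.foldable` rather than pattern-matching a string `Literal`, so any expression the optimizer can reduce to a constant string is accepted. A minimal sketch of the relaxed behaviour (assuming a live `SparkSession` bound to `spark`):

    // `concat_ws` over literals is foldable, so schema_of_csv now accepts it;
    // a column reference is not foldable and still fails at analysis time.
    import org.apache.spark.sql.functions.{concat_ws, lit, schema_of_csv}

    spark.range(1)
      .select(schema_of_csv(concat_ws(",", lit(0.1), lit(1))))
      .show(false) // struct<_c0:double,_c1:int>
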
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 61afdb6c9492..aa4b464850f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -764,10 +764,14 @@ case class SchemaOfJson(
   @transient
   private lazy val json = child.eval().asInstanceOf[UTF8String]
 
-  override def checkInputDataTypes(): TypeCheckResult = child match {
-    case Literal(s, StringType) if s != null => super.checkInputDataTypes()
-    case _ => TypeCheckResult.TypeCheckFailure(
-      s"The input json should be a string literal and not null; however, got ${child.sql}.")
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (child.foldable && json != null) {
+      super.checkInputDataTypes()
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        "The input json should be a foldable string expression and not null; " +
+          s"however, got ${child.sql}.")
+    }
   }
 
   override def eval(v: InternalRow): Any = {
diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out
index 8495bef9122e..be7fa5e9d5ff 100644
--- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out
@@ -24,7 +24,7 @@ select from_csv('1', 1)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-Schema should be specified in DDL format as a string literal or output of the schema_of_csv function instead of 1;; line 1 pos 7
+The expression '1' is not a valid schema string.;; line 1 pos 7
 
 
 -- !query
@@ -91,7 +91,7 @@ select schema_of_csv(null)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'schema_of_csv(NULL)' due to data type mismatch: The input csv should be a string literal and not null; however, got NULL.; line 1 pos 7
+cannot resolve 'schema_of_csv(NULL)' due to data type mismatch: The input csv should be a foldable string expression and not null; however, got NULL.; line 1 pos 7
 
 
 -- !query
@@ -108,7 +108,7 @@ SELECT schema_of_csv(csvField) FROM csvTable
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'schema_of_csv(csvtable.`csvField`)' due to data type mismatch: The input csv should be a string literal and not null; however, got csvtable.`csvField`.; line 1 pos 7
+cannot resolve 'schema_of_csv(csvtable.`csvField`)' due to data type mismatch: The input csv should be a foldable string expression and not null; however, got csvtable.`csvField`.; line 1 pos 7
 
 
 -- !query
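The golden-file updates here track only the reworded error messages. For instance, the first csv hunk corresponds to a session roughly like the following (a sketch; shell output abbreviated and approximate):

    spark-sql> SELECT from_csv('1', 1);
    Error in query: The expression '1' is not a valid schema string.; line 1 pos 7
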
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
index 21a3531caf73..920b45a8fa77 100644
--- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -115,7 +115,7 @@ select from_json('{"a":1}', 1)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7
+The expression '1' is not a valid schema string.;; line 1 pos 7
 
 
 -- !query
@@ -326,7 +326,7 @@ select schema_of_json(null)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'schema_of_json(NULL)' due to data type mismatch: The input json should be a string literal and not null; however, got NULL.; line 1 pos 7
+cannot resolve 'schema_of_json(NULL)' due to data type mismatch: The input json should be a foldable string expression and not null; however, got NULL.; line 1 pos 7
 
 
 -- !query
@@ -343,7 +343,7 @@ SELECT schema_of_json(jsonField) FROM jsonTable
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'schema_of_json(jsontable.`jsonField`)' due to data type mismatch: The input json should be a string literal and not null; however, got jsontable.`jsonField`.; line 1 pos 7
+cannot resolve 'schema_of_json(jsontable.`jsonField`)' due to data type mismatch: The input json should be a foldable string expression and not null; however, got jsontable.`jsonField`.; line 1 pos 7
 
 
 -- !query
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
index 61f0e138cc35..54dfb4597b04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
@@ -200,4 +200,30 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession {
       assert(readback(0).getAs[Row](0).getAs[Date](0).getTime >= 0)
     }
   }
+
+  test("support foldable schema by from_csv") {
+    val options = Map[String, String]().asJava
+    val schema = concat_ws(",", lit("i int"), lit("s string"))
+    checkAnswer(
+      Seq("""1,"a"""").toDS().select(from_csv($"value", schema, options)),
+      Row(Row(1, "a")))
+
+    val errMsg = intercept[AnalysisException] {
+      Seq(("1", "i int")).toDF("csv", "schema")
+        .select(from_csv($"csv", $"schema", options)).collect()
+    }.getMessage
+    assert(errMsg.contains("Schema should be specified in DDL format as a string literal"))
+
+    val errMsg2 = intercept[AnalysisException] {
+      Seq("1").toDF("csv").select(from_csv($"csv", lit(1), options)).collect()
+    }.getMessage
+    assert(errMsg2.contains("The expression '1' is not a valid schema string"))
+  }
+
+  test("schema_of_csv - infers the schema of foldable CSV string") {
+    val input = concat_ws(",", lit(0.1), lit(1))
+    checkAnswer(
+      spark.range(1).select(schema_of_csv(input)),
+      Seq(Row("struct<_c0:double,_c1:int>")))
+  }
 }
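The new from_csv tests above use the `from_csv(Column, Column, java.util.Map)` overload. Outside the suite, the happy path looks roughly like this (a sketch; assumes `spark.implicits._` is in scope, as it is in the suite):

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.functions.{concat_ws, from_csv, lit}

    // The schema column is assembled at runtime but is still foldable, so
    // ExprUtils.evalSchemaExpr can evaluate it during analysis.
    val schema = concat_ws(",", lit("i int"), lit("s string"))
    Seq("""1,"a"""").toDS()
      .select(from_csv($"value", schema, Map.empty[String, String].asJava))
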
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index fd1e9e309558..ebc2f57a984d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -313,7 +313,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
     val errMsg1 = intercept[AnalysisException] {
       df3.selectExpr("from_json(value, 1)")
     }
-    assert(errMsg1.getMessage.startsWith("Schema should be specified in DDL format as a string"))
+    assert(errMsg1.getMessage.startsWith("The expression '1' is not a valid schema string"))
     val errMsg2 = intercept[AnalysisException] {
       df3.selectExpr("""from_json(value, 'time InvalidType')""")
     }
@@ -653,4 +653,25 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
       assert(json_tuple_result === len)
     }
   }
+
+  test("support foldable schema by from_json") {
+    val options = Map[String, String]().asJava
+    val schema = regexp_replace(lit("dpt_org_id INT, dpt_org_city STRING"), "dpt_org_", "")
+    checkAnswer(
+      Seq("""{"id":1,"city":"Moscow"}""").toDS().select(from_json($"value", schema, options)),
+      Row(Row(1, "Moscow")))
+
+    val errMsg = intercept[AnalysisException] {
+      Seq(("""{"i":1}""", "i int")).toDF("json", "schema")
+        .select(from_json($"json", $"schema", options)).collect()
+    }.getMessage
+    assert(errMsg.contains("Schema should be specified in DDL format as a string literal"))
+  }
+
+  test("schema_of_json - infers the schema of foldable JSON string") {
+    val input = regexp_replace(lit("""{"item_id": 1, "item_price": 0.1}"""), "item_", "")
+    checkAnswer(
+      spark.range(1).select(schema_of_json(input)),
+      Seq(Row("struct<id:bigint,price:double>")))
+  }
 }
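An analogous sketch for the JSON side, under the same assumptions as the csv example above:

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.functions.{from_json, lit, regexp_replace}

    // regexp_replace over a literal folds to "id INT, city STRING" at analysis time.
    val schema = regexp_replace(lit("dpt_org_id INT, dpt_org_city STRING"), "dpt_org_", "")
    Seq("""{"id":1,"city":"Moscow"}""").toDS()
      .select(from_json($"value", schema, Map.empty[String, String].asJava))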