diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 4005087dad05..0978e92dd4f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -155,6 +155,18 @@ package object util { def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql + + def escapeSingleQuotedString(str: String): String = { + val builder = StringBuilder.newBuilder + + str.foreach { + case '\'' => builder ++= s"\\\'" + case ch => builder += ch + } + + builder.toString() + } + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index 2c18fdcc497f..902cae9150ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -21,6 +21,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} /** * A field inside a StructType. @@ -74,4 +75,16 @@ case class StructField( def getComment(): Option[String] = { if (metadata.contains("comment")) Option(metadata.getString("comment")) else None } + + /** + * Returns a string containing a schema in DDL format. For example, the following value: + * `StructField("eventId", IntegerType)` will be converted to `eventId` INT. + */ + def toDDL: String = { + val comment = getComment() + .map(escapeSingleQuotedString) + .map(" COMMENT '" + _ + "'") + + s"${quoteIdentifier(name)} ${dataType.sql}${comment.getOrElse("")}" + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b13e95f83bc5..c5ca169c955d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} -import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} import org.apache.spark.util.Utils /** @@ -360,6 +360,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru s"STRUCT<${fieldTypes.mkString(", ")}>" } + /** + * Returns a string containing a schema in DDL format. For example, the following value: + * `StructType(Seq(StructField("eventId", IntegerType), StructField("s", StringType)))` + * will be converted to `eventId` INT, `s` STRING. + * The returned DDL schema can be used in a table creation. + */ + def toDDL: String = fields.map(_.toDDL).mkString(",") + private[sql] override def simpleString(maxNumberFields: Int): String = { val builder = new StringBuilder val fieldTypes = fields.take(maxNumberFields).map { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index c6ca8bb00542..53a78c94aa6f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.StructType.fromDDL class StructTypeSuite extends SparkFunSuite { @@ -37,4 +38,36 @@ class StructTypeSuite extends SparkFunSuite { val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage assert(e.contains("Available fields: a, b")) } + + test("SPARK-24849: toDDL - simple struct") { + val struct = StructType(Seq(StructField("a", IntegerType))) + + assert(struct.toDDL == "`a` INT") + } + + test("SPARK-24849: round trip toDDL - fromDDL") { + val struct = new StructType().add("a", IntegerType).add("b", StringType) + + assert(fromDDL(struct.toDDL) === struct) + } + + test("SPARK-24849: round trip fromDDL - toDDL") { + val struct = "`a` MAP,`b` INT" + + assert(fromDDL(struct).toDDL === struct) + } + + test("SPARK-24849: toDDL must take into account case of fields.") { + val struct = new StructType() + .add("metaData", new StructType().add("eventId", StringType)) + + assert(struct.toDDL == "`metaData` STRUCT<`eventId`: STRING>") + } + + test("SPARK-24849: toDDL should output field's comment") { + val struct = StructType(Seq( + StructField("b", BooleanType).withComment("Field's comment"))) + + assert(struct.toDDL == """`b` BOOLEAN COMMENT 'Field\'s comment'""") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index ec3961f84bd8..56f48b7dc00e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.Histogram -import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -982,7 +982,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman private def showHiveTableHeader(metadata: CatalogTable, builder: StringBuilder): Unit = { val columns = metadata.schema.filterNot { column => metadata.partitionColumnNames.contains(column.name) - }.map(columnToDDLFragment) + }.map(_.toDDL) if (columns.nonEmpty) { builder ++= columns.mkString("(", ", ", ")\n") @@ -994,14 +994,10 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman .foreach(builder.append) } - private def columnToDDLFragment(column: StructField): String = { - val comment = column.getComment().map(escapeSingleQuotedString).map(" COMMENT '" + _ + "'") - s"${quoteIdentifier(column.name)} ${column.dataType.catalogString}${comment.getOrElse("")}" - } private def showHiveTableNonDataColumns(metadata: CatalogTable, builder: StringBuilder): Unit = { if (metadata.partitionColumnNames.nonEmpty) { - val partCols = metadata.partitionSchema.map(columnToDDLFragment) + val partCols = metadata.partitionSchema.map(_.toDDL) builder ++= partCols.mkString("PARTITIONED BY (", ", ", ")\n") } @@ -1072,7 +1068,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman private def showDataSourceTableDataColumns( metadata: CatalogTable, builder: StringBuilder): Unit = { - val columns = metadata.schema.fields.map(f => s"${quoteIdentifier(f.name)} ${f.dataType.sql}") + val columns = metadata.schema.fields.map(_.toDDL) builder ++= columns.mkString("(", ", ", ")\n") } @@ -1117,15 +1113,4 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman } } } - - private def escapeSingleQuotedString(str: String): String = { - val builder = StringBuilder.newBuilder - - str.foreach { - case '\'' => builder ++= s"\\\'" - case ch => builder += ch - } - - builder.toString() - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index 473bbced41b3..34ca79029985 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -288,6 +288,21 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } } + test("SPARK-24911: keep quotes for nested fields") { + withTable("t1") { + val createTable = "CREATE TABLE `t1`(`a` STRUCT<`b`: STRING>)" + sql(createTable) + val shownDDL = sql(s"SHOW CREATE TABLE t1") + .head() + .getString(0) + .split("\n") + .head + assert(shownDDL == createTable) + + checkCreateTable("t1") + } + } + private def createRawHiveTable(ddl: String): Unit = { hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] .client.runSqlHive(ddl)