diff --git a/docs/sql-ref-datatypes.md b/docs/sql-ref-datatypes.md index f27f1a0ca967..0087867a8c7f 100644 --- a/docs/sql-ref-datatypes.md +++ b/docs/sql-ref-datatypes.md @@ -37,6 +37,8 @@ Spark SQL and DataFrames support the following data types: - `DecimalType`: Represents arbitrary-precision signed decimal numbers. Backed internally by `java.math.BigDecimal`. A `BigDecimal` consists of an arbitrary precision integer unscaled value and a 32-bit integer scale. * String type - `StringType`: Represents character string values. + - `VarcharType(length)`: A variant of `StringType` which has a length limitation. Data writing will fail if the input string exceeds the length limitation. Note: this type can only be used in table schemas, not in functions/operators. + - `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Reading a column of type `CharType(n)` always returns string values of length `n`. Char type column comparisons pad the shorter value to the longer length. * Binary type - `BinaryType`: Represents byte sequence values. * Boolean type diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 837686420375..af1a2bc9db97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} @@ -3097,7 +3097,12 @@ class Analyzer(override val catalogManager: CatalogManager) val projection = TableOutputResolver.resolveOutputColumns( v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf) if (projection != v2Write.query) { - v2Write.withNewQuery(projection) + val cleanedTable = v2Write.table match { + case r: DataSourceV2Relation => + r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata)) + case other => other + } + v2Write.withNewQuery(projection).withNewTable(cleanedTable) } else { v2Write } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9998035d65c3..b1a06a3c855e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils} import 
org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement, SupportsPartitionManagement, Table} import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.internal.SQLConf @@ -94,6 +94,10 @@ trait CheckAnalysis extends PredicateHelper { case p if p.analyzed => // Skip already analyzed sub-plans + case leaf: LeafNode if leaf.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) => + throw new IllegalStateException( + "[BUG] logical plan should not have output of char/varchar type: " + leaf) + case u: UnresolvedNamespace => u.failAnalysis(s"Namespace not found: ${u.multipartIdentifier.quoted}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 7354d2478b7c..a90de697bc08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -35,7 +35,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case AlterTableAddColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => cols.foreach(c => failNullType(c.dataType)) - cols.foreach(c => failCharType(c.dataType)) val changes = cols.map { col => TableChange.addColumn( col.name.toArray, @@ -49,7 +48,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case AlterTableReplaceColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => cols.foreach(c => failNullType(c.dataType)) - cols.foreach(c => failCharType(c.dataType)) val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match { case Some(table) => // REPLACE COLUMNS deletes all the existing columns and adds new columns specified. 
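The docs entry at the top of this patch describes the user-facing contract. A minimal spark-shell style sketch of that behaviour, assuming a hypothetical table created with the new types (the write-side checks and read-side padding are the ones added by CharVarcharUtils further down):

```scala
spark.sql("CREATE TABLE people (name CHAR(5), nick VARCHAR(3)) USING parquet")

spark.sql("INSERT INTO people VALUES ('Tom', 'T')")          // ok: both values fit the declared lengths
spark.sql("INSERT INTO people VALUES ('Tom   ', 'T  ')")     // ok: trailing spaces do not count in the check
spark.sql("INSERT INTO people VALUES ('Tommy Atkins', 'T')") // fails: 'Tommy Atkins' exceeds char(5)

// CHAR(5) values come back right-padded to length 5; VARCHAR values are returned as stored.
spark.sql("SELECT name, length(name), nick, length(nick) FROM people").show()
```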
@@ -72,7 +70,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case a @ AlterTableAlterColumnStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _) => a.dataType.foreach(failNullType) - a.dataType.foreach(failCharType) val colName = a.column.toArray val typeChange = a.dataType.map { newDataType => TableChange.updateColumnType(colName, newDataType) @@ -145,7 +142,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ CreateTableStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => assertNoNullTypeInSchema(c.tableSchema) - assertNoCharTypeInSchema(c.tableSchema) CreateV2Table( catalog.asTableCatalog, tbl.asIdentifier, @@ -173,7 +169,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ ReplaceTableStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => assertNoNullTypeInSchema(c.tableSchema) - assertNoCharTypeInSchema(c.tableSchema) ReplaceTable( catalog.asTableCatalog, tbl.asIdentifier, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala index 6d061fce0691..98c6872a47cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddPartition, AlterTableDropPartition, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement import org.apache.spark.sql.types._ import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec @@ -66,7 +67,8 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] { val partValues = partSchema.map { part => val raw = normalizedSpec.get(part.name).orNull - Cast(Literal.create(raw, StringType), part.dataType, Some(conf.sessionLocalTimeZone)).eval() + val dt = CharVarcharUtils.replaceCharVarcharWithString(part.dataType) + Cast(Literal.create(raw, StringType), dt, Some(conf.sessionLocalTimeZone)).eval() } InternalRow.fromSeq(partValues) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index 4f33ca99c02d..d5c407b47c5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.types.DataType @@ -93,19 +94,17 @@ object TableOutputResolver { tableAttr.metadata == queryExpr.metadata) { Some(queryExpr) } else { - // Renaming is 
needed for handling the following cases like - // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 - // 2) Target tables have column metadata - storeAssignmentPolicy match { + val casted = storeAssignmentPolicy match { case StoreAssignmentPolicy.ANSI => - Some(Alias( - AnsiCast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)), - tableAttr.name)(explicitMetadata = Option(tableAttr.metadata))) + AnsiCast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) case _ => - Some(Alias( - Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)), - tableAttr.name)(explicitMetadata = Option(tableAttr.metadata))) + Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) } + val exprWithStrLenCheck = CharVarcharUtils.stringLengthCheck(casted, tableAttr) + // Renaming is needed for handling the following cases like + // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 + // 2) Target tables have column metadata + Some(Alias(exprWithStrLenCheck, tableAttr.name)(explicitMetadata = Some(tableAttr.metadata))) } storeAssignmentPolicy match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 17ab6664df75..a79c26a98598 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} -import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE @@ -473,7 +473,10 @@ class SessionCatalog( val table = formatTableName(name.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Some(db))) - externalCatalog.getTable(db, table) + val t = externalCatalog.getTable(db, table) + // We replace char/varchar with "annotated" string type in the table schema, as the query + // engine doesn't support char/varchar yet. 
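For reference, a small sketch of the "annotated" string type mentioned in the comment above, using the CharVarcharUtils helpers introduced later in this patch (field names are illustrative):

```scala
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types._

// Char/varchar fields become StringType, with the original type string kept in the field
// metadata so it can be recovered later via getRawType.
val raw = new StructType().add("name", CharType(10)).add("id", IntegerType)
val cleaned = CharVarcharUtils.replaceCharVarcharWithStringInSchema(raw)

cleaned("name").dataType                              // StringType
CharVarcharUtils.getRawType(cleaned("name").metadata) // Some(CharType(10))
CharVarcharUtils.getRawType(cleaned("id").metadata)   // None: non-char/varchar fields are untouched
```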
+ t.copy(schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(t.schema)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 25423e510157..d173756a45f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -36,8 +36,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} -import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition @@ -99,7 +99,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = { - withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList))) + val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema( + StructType(visitColTypeList(ctx.colTypeList))) + withOrigin(ctx)(schema) } def parseRawDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { @@ -2216,7 +2218,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * Create a Spark DataType. */ private def visitSparkDataType(ctx: DataTypeContext): DataType = { - HiveStringType.replaceCharType(typedVisit(ctx)) + CharVarcharUtils.replaceCharVarcharWithString(typedVisit(ctx)) } /** @@ -2291,16 +2293,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg builder.putString("comment", _) } - // Add Hive type string to metadata. 
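Because visitSingleTableSchema now cleans the schema, parsed table schemas carry the same annotation. A quick sketch, assuming parseTableSchema routes through visitSingleTableSchema as the updated TableSchemaParserSuite below expects:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.CharVarcharUtils

val schema = CatalystSqlParser.parseTableSchema("name CHAR(10), nick VARCHAR(5)")
schema.map(_.dataType)                                   // Seq(StringType, StringType)
schema.map(f => CharVarcharUtils.getRawType(f.metadata)) // Seq(Some(CharType(10)), Some(VarcharType(5)))
```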
- val rawDataType = typedVisit[DataType](ctx.dataType) - val cleanedDataType = HiveStringType.replaceCharType(rawDataType) - if (rawDataType != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString) - } - StructField( name = colName.getText, - dataType = cleanedDataType, + dataType = typedVisit[DataType](ctx.dataType), nullable = NULL == null, metadata = builder.build()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index ebf41f6a6e30..4931f0eb2c00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PartitionSpec, Res import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChange} import org.apache.spark.sql.connector.expressions.Transform @@ -45,9 +46,10 @@ trait V2WriteCommand extends Command { table.skipSchemaResolution || (query.output.size == table.output.size && query.output.zip(table.output).forall { case (inAttr, outAttr) => + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) // names and types must match, nullability must be compatible inAttr.name == outAttr.name && - DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outAttr.dataType) && + DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outType) && (outAttr.nullable || !inAttr.nullable) }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala new file mode 100644 index 000000000000..0cbe5abdbbd7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.types._ + +object CharVarcharUtils { + + private val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING" + + /** + * Replaces CharType/VarcharType with StringType recursively in the given struct type. If a + * top-level StructField's data type is CharType/VarcharType or has nested CharType/VarcharType, + * this method will add the original type string to the StructField's metadata, so that we can + * re-construct the original data type with CharType/VarcharType later when needed. + */ + def replaceCharVarcharWithStringInSchema(st: StructType): StructType = { + StructType(st.map { field => + if (hasCharVarchar(field.dataType)) { + val metadata = new MetadataBuilder().withMetadata(field.metadata) + .putString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY, field.dataType.sql).build() + field.copy(dataType = replaceCharVarcharWithString(field.dataType), metadata = metadata) + } else { + field + } + }) + } + + /** + * Returns true if the given data type is CharType/VarcharType or has nested CharType/VarcharType. + */ + def hasCharVarchar(dt: DataType): Boolean = { + dt.existsRecursively(f => f.isInstanceOf[CharType] || f.isInstanceOf[VarcharType]) + } + + /** + * Replaces CharType/VarcharType with StringType recursively in the given data type. + */ + def replaceCharVarcharWithString(dt: DataType): DataType = dt match { + case ArrayType(et, nullable) => + ArrayType(replaceCharVarcharWithString(et), nullable) + case MapType(kt, vt, nullable) => + MapType(replaceCharVarcharWithString(kt), replaceCharVarcharWithString(vt), nullable) + case StructType(fields) => + StructType(fields.map { field => + field.copy(dataType = replaceCharVarcharWithString(field.dataType)) + }) + case _: CharType => StringType + case _: VarcharType => StringType + case _ => dt + } + + /** + * Removes the metadata entry that contains the original type string of CharType/VarcharType from + * the given attribute's metadata. + */ + def cleanAttrMetadata(attr: AttributeReference): AttributeReference = { + val cleaned = new MetadataBuilder().withMetadata(attr.metadata) + .remove(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY).build() + attr.withMetadata(cleaned) + } + + /** + * Re-construct the original data type from the type string in the given metadata. + * This is needed when dealing with char/varchar columns/fields. + */ + def getRawType(metadata: Metadata): Option[DataType] = { + if (metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)) { + Some(CatalystSqlParser.parseRawDataType( + metadata.getString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY))) + } else { + None + } + } + + /** + * Returns expressions to apply read-side char type padding for the given attributes. String + * values should be right-padded to N characters if it's from a CHAR(N) column/field. 
+ */ + def charTypePadding(output: Seq[AttributeReference]): Seq[NamedExpression] = { + output.map { attr => + getRawType(attr.metadata).filter { rawType => + rawType.existsRecursively(_.isInstanceOf[CharType]) + }.map { rawType => + Alias(charTypePadding(attr, rawType), attr.name)(explicitMetadata = Some(attr.metadata)) + }.getOrElse(attr) + } + } + + private def charTypePadding(expr: Expression, dt: DataType): Expression = dt match { + case CharType(length) => StringRPad(expr, Literal(length)) + + case StructType(fields) => + val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => + Seq(Literal(f.name), charTypePadding(GetStructField(expr, i, Some(f.name)), f.dataType)) + }) + if (expr.nullable) { + If(IsNull(expr), Literal(null, struct.dataType), struct) + } else { + struct + } + + case ArrayType(et, containsNull) => charTypePaddingInArray(expr, et, containsNull) + + case MapType(kt, vt, valueContainsNull) => + val newKeys = charTypePaddingInArray(MapKeys(expr), kt, containsNull = false) + val newValues = charTypePaddingInArray(MapValues(expr), vt, valueContainsNull) + MapFromArrays(newKeys, newValues) + + case _ => expr + } + + private def charTypePaddingInArray( + arr: Expression, et: DataType, containsNull: Boolean): Expression = { + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) + val func = LambdaFunction(charTypePadding(param, et), Seq(param)) + ArrayTransform(arr, func) + } + + /** + * Returns an expression to apply write-side string length check for the given expression. A + * string value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N) + * column/field. + */ + def stringLengthCheck(expr: Expression, targetAttr: Attribute): Expression = { + getRawType(targetAttr.metadata).map { rawType => + stringLengthCheck(expr, rawType) + }.getOrElse(expr) + } + + private def raiseError(expr: Expression, typeName: String, length: Int): Expression = { + val errorMsg = Concat(Seq( + Literal("input string '"), + expr, + Literal(s"' exceeds $typeName type length limitation: $length"))) + Cast(RaiseError(errorMsg), StringType) + } + + private def stringLengthCheck(expr: Expression, dt: DataType): Expression = dt match { + case CharType(length) => + val trimmed = StringTrimRight(expr) + // Trailing spaces do not count in the length check. We don't need to retain the trailing + // spaces, as we will pad char type columns/fields at read time. + If( + GreaterThan(Length(trimmed), Literal(length)), + raiseError(expr, "char", length), + trimmed) + + case VarcharType(length) => + val trimmed = StringTrimRight(expr) + // Trailing spaces do not count in the length check. We need to retain the trailing spaces + // (truncate to length N), as there is no read-time padding for varchar type. + // TODO: create a special TrimRight function that can trim to a certain length. 
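A quick sketch of how these length-check expressions get attached to a write target (the attribute below is built by hand for illustration; TableOutputResolver above is the real call site):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types._

// A write target declared CHAR(5): after cleaning it is a StringType attribute whose
// metadata still records CHAR(5).
val cleaned = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
  new StructType().add("c", CharType(5)))
val target = cleaned.toAttributes.head

// Roughly: if (length(rtrim(input)) > 5) raise_error(...) else rtrim(input).
// The varchar variant keeps inputs of length <= 5 as-is, and otherwise errors out or
// pads the trimmed value back to 5.
val checked = CharVarcharUtils.stringLengthCheck(Literal("abcdef "), target)
```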
+ If( + LessThanOrEqual(Length(expr), Literal(length)), + expr, + If( + GreaterThan(Length(trimmed), Literal(length)), + raiseError(expr, "varchar", length), + StringRPad(trimmed, Literal(length)))) + + case StructType(fields) => + val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => + Seq(Literal(f.name), stringLengthCheck(GetStructField(expr, i, Some(f.name)), f.dataType)) + }) + if (expr.nullable) { + If(IsNull(expr), Literal(null, struct.dataType), struct) + } else { + struct + } + + case ArrayType(et, containsNull) => stringLengthCheckInArray(expr, et, containsNull) + + case MapType(kt, vt, valueContainsNull) => + val newKeys = stringLengthCheckInArray(MapKeys(expr), kt, containsNull = false) + val newValues = stringLengthCheckInArray(MapValues(expr), vt, valueContainsNull) + MapFromArrays(newKeys, newValues) + + case _ => expr + } + + private def stringLengthCheckInArray( + arr: Expression, et: DataType, containsNull: Boolean): Expression = { + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) + val func = LambdaFunction(stringLengthCheck(param, et), Seq(param)) + ArrayTransform(arr, func) + } + + /** + * Return expressions to apply char type padding for the string comparison between the given + * attributes. When comparing two char type columns/fields, we need to pad the shorter one to + * the longer length. + */ + def addPaddingInStringComparison(attrs: Seq[Attribute]): Seq[Expression] = { + val rawTypes = attrs.map(attr => getRawType(attr.metadata)) + if (rawTypes.exists(_.isEmpty)) { + attrs + } else { + val typeWithTargetCharLength = rawTypes.map(_.get).reduce(typeWithWiderCharLength) + attrs.zip(rawTypes.map(_.get)).map { case (attr, rawType) => + padCharToTargetLength(attr, rawType, typeWithTargetCharLength).getOrElse(attr) + } + } + } + + private def typeWithWiderCharLength(type1: DataType, type2: DataType): DataType = { + (type1, type2) match { + case (CharType(len1), CharType(len2)) => + CharType(math.max(len1, len2)) + case (StructType(fields1), StructType(fields2)) => + assert(fields1.length == fields2.length) + StructType(fields1.zip(fields2).map { case (left, right) => + StructField("", typeWithWiderCharLength(left.dataType, right.dataType)) + }) + case (ArrayType(et1, _), ArrayType(et2, _)) => + ArrayType(typeWithWiderCharLength(et1, et2)) + case _ => NullType + } + } + + private def padCharToTargetLength( + expr: Expression, + rawType: DataType, + typeWithTargetCharLength: DataType): Option[Expression] = { + (rawType, typeWithTargetCharLength) match { + case (CharType(len), CharType(target)) if target > len => + Some(StringRPad(expr, Literal(target))) + + case (StructType(fields), StructType(targets)) => + assert(fields.length == targets.length) + var i = 0 + var needPadding = false + val createStructExprs = mutable.ArrayBuffer.empty[Expression] + while (i < fields.length) { + val field = fields(i) + val fieldExpr = GetStructField(expr, i, Some(field.name)) + val padded = padCharToTargetLength(fieldExpr, field.dataType, targets(i).dataType) + needPadding = padded.isDefined + createStructExprs += Literal(field.name) + createStructExprs += padded.getOrElse(fieldExpr) + i += 1 + } + if (needPadding) Some(CreateNamedStruct(createStructExprs.toSeq)) else None + + case (ArrayType(et, containsNull), ArrayType(target, _)) => + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) + padCharToTargetLength(param, et, target).map { padded => + val func = LambdaFunction(padded, Seq(param)) + 
ArrayTransform(expr, func) + } + + // We don't handle MapType here as it's not comparable. + + case _ => None + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index b6dc4f61c858..02db2293ec64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -24,11 +24,10 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, CreateTableAsSelectStatement, CreateTableStatement, ReplaceTableAsSelectStatement, ReplaceTableStatement, SerdeInfo} import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.types.{ArrayType, DataType, HIVE_TYPE_STRING, HiveStringType, MapType, NullType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils @@ -379,21 +378,6 @@ private[sql] object CatalogV2Util { .asTableCatalog } - def failCharType(dt: DataType): Unit = { - if (HiveStringType.containsCharType(dt)) { - throw new AnalysisException( - "Cannot use CHAR type in non-Hive-Serde tables, please use STRING type instead.") - } - } - - def assertNoCharTypeInSchema(schema: StructType): Unit = { - schema.foreach { f => - if (f.metadata.contains(HIVE_TYPE_STRING)) { - failCharType(CatalystSqlParser.parseRawDataType(f.metadata.getString(HIVE_TYPE_STRING))) - } - } - } - def failNullType(dt: DataType): Unit = { def containsNullType(dt: DataType): Boolean = dt match { case ArrayType(et, _) => containsNullType(et) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index f541411daeff..4debdd380e6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability} import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics} import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream} @@ -171,8 +171,10 @@ object DataSourceV2Relation { catalog: Option[CatalogPlugin], identifier: Option[Identifier], 
options: CaseInsensitiveStringMap): DataSourceV2Relation = { - val output = table.schema().toAttributes - DataSourceV2Relation(table, output, catalog, identifier, options) + // The v2 source may return schema containing char/varchar type. We replace char/varchar + // with "annotated" string type here as the query engine doesn't support char/varchar yet. + val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(table.schema) + DataSourceV2Relation(table, schema.toAttributes, catalog, identifier, options) } def create( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala new file mode 100644 index 000000000000..67ab1cc2f332 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.unsafe.types.UTF8String + +@Experimental +case class CharType(length: Int) extends AtomicType { + require(length >= 0, "The length of char type cannot be negative.") + + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = typeTag[InternalType] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + override def defaultSize: Int = length + override def typeName: String = s"char($length)" + override def toString: String = s"CharType($length)" + private[spark] override def asNullable: CharType = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 7556a19f0d31..e4ee6eb377a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -124,13 +124,15 @@ abstract class DataType extends AbstractDataType { object DataType { private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r + private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r + private val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r def fromDDL(ddl: String): DataType = { parseTypeWithFallback( ddl, CatalystSqlParser.parseDataType, "Cannot parse the data type: ", - fallbackParser = CatalystSqlParser.parseTableSchema) + fallbackParser = str => CatalystSqlParser.parseTableSchema(str)) } /** @@ -166,7 +168,7 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) - private val nonDecimalNameToType = { + private val otherTypes = { Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, 
DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType) .map(t => t.typeName -> t).toMap @@ -177,7 +179,9 @@ object DataType { name match { case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType.getOrElse( + case CHAR_TYPE(length) => CharType(length.toInt) + case VARCHAR_TYPE(length) => VarcharType(length.toInt) + case other => otherTypes.getOrElse( other, throw new IllegalArgumentException( s"Failed to convert the JSON string '$name' to a data type.")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala deleted file mode 100644 index a29f49ad14a7..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.types - -import scala.math.Ordering -import scala.reflect.runtime.universe.typeTag - -import org.apache.spark.unsafe.types.UTF8String - -/** - * A hive string type for compatibility. These datatypes should only used for parsing, - * and should NOT be used anywhere else. Any instance of these data types should be - * replaced by a [[StringType]] before analysis. - */ -sealed abstract class HiveStringType extends AtomicType { - private[sql] type InternalType = UTF8String - - private[sql] val ordering = implicitly[Ordering[InternalType]] - - @transient private[sql] lazy val tag = typeTag[InternalType] - - override def defaultSize: Int = length - - private[spark] override def asNullable: HiveStringType = this - - def length: Int -} - -object HiveStringType { - def replaceCharType(dt: DataType): DataType = dt match { - case ArrayType(et, nullable) => - ArrayType(replaceCharType(et), nullable) - case MapType(kt, vt, nullable) => - MapType(replaceCharType(kt), replaceCharType(vt), nullable) - case StructType(fields) => - StructType(fields.map { field => - field.copy(dataType = replaceCharType(field.dataType)) - }) - case _: HiveStringType => StringType - case _ => dt - } - - def containsCharType(dt: DataType): Boolean = dt match { - case ArrayType(et, _) => containsCharType(et) - case MapType(kt, vt, _) => containsCharType(kt) || containsCharType(vt) - case StructType(fields) => fields.exists(f => containsCharType(f.dataType)) - case _ => dt.isInstanceOf[CharType] - } -} - -/** - * Hive char type. Similar to other HiveStringType's, these datatypes should only used for - * parsing, and should NOT be used anywhere else. Any instance of these data types should be - * replaced by a [[StringType]] before analysis. 
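The new CHAR_TYPE/VARCHAR_TYPE patterns let char/varchar survive a JSON round trip of the schema. A small sketch, assuming the default DataType.json serialisation (a JSON string holding typeName):

```scala
import org.apache.spark.sql.types._

val json = CharType(5).json          // "\"char(5)\""
DataType.fromJson(json)              // CharType(5)
DataType.fromJson("\"varchar(10)\"") // VarcharType(10)
```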
- */ -case class CharType(length: Int) extends HiveStringType { - override def simpleString: String = s"char($length)" -} - -/** - * Hive varchar type. Similar to other HiveStringType's, these datatypes should only used for - * parsing, and should NOT be used anywhere else. Any instance of these data types should be - * replaced by a [[StringType]] before analysis. - */ -case class VarcharType(length: Int) extends HiveStringType { - override def simpleString: String = s"varchar($length)" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala new file mode 100644 index 000000000000..8d78640c1e12 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.unsafe.types.UTF8String + +@Experimental +case class VarcharType(length: Int) extends AtomicType { + require(length >= 0, "The length of varchar type cannot be negative.") + + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = typeTag[InternalType] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + override def defaultSize: Int = length + override def typeName: String = s"varchar($length)" + override def toString: String = s"VarcharType($length)" + private[spark] override def asNullable: VarcharType = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala index f29cbc2069e3..346a51ea10c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala @@ -21,12 +21,4 @@ package org.apache.spark.sql * Contains a type system for attributes produced by relations, including complex types like * structs, arrays and maps. */ -package object types { - /** - * Metadata key used to store the raw hive type string in the metadata of StructField. This - * is relevant for datatypes that do not have a direct Spark SQL counterpart, such as CHAR and - * VARCHAR. We need to preserve the original type in order to invoke the correct object - * inspector in Hive. 
- */ - val HIVE_TYPE_STRING = "HIVE_TYPE_STRING" -} +package object types diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f0a24d4a5604..0afa811e5d59 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.TimeZone +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -41,9 +42,11 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.connector.InMemoryTable +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - +import org.apache.spark.sql.util.CaseInsensitiveStringMap class AnalysisSuite extends AnalysisTest with Matchers { import org.apache.spark.sql.catalyst.analysis.TestRelations._ @@ -55,6 +58,19 @@ class AnalysisSuite extends AnalysisTest with Matchers { } } + test("fail if a leaf node has char/varchar type output") { + val schema1 = new StructType().add("c", CharType(5)) + val schema2 = new StructType().add("c", VarcharType(5)) + val schema3 = new StructType().add("c", ArrayType(CharType(5))) + Seq(schema1, schema2, schema3).foreach { schema => + val table = new InMemoryTable("t", schema, Array.empty, Map.empty[String, String].asJava) + intercept[IllegalStateException] { + DataSourceV2Relation( + table, schema.toAttributes, None, None, CaseInsensitiveStringMap.empty()).analyze + } + } + } + test("union project *") { val plan = (1 to 120) .map(_ => testRelation) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala index 6803fc307f91..95851d44b474 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.types._ class TableSchemaParserSuite extends SparkFunSuite { @@ -57,11 +58,6 @@ class TableSchemaParserSuite extends SparkFunSuite { |anotherArray:Array> """.stripMargin.replace("\n", "") - val builder = new MetadataBuilder - builder.putString(HIVE_TYPE_STRING, - "struct," + - "MAP:map,arrAy:array,anotherArray:array>") - val expectedDataType = StructType( StructField("complexStructCol", StructType( @@ -69,13 +65,12 @@ class TableSchemaParserSuite extends SparkFunSuite { StructType( StructField("deciMal", DecimalType.USER_DEFAULT) :: StructField("anotherDecimal", DecimalType(5, 2)) :: Nil)) :: - StructField("MAP", MapType(TimestampType, StringType)) :: + StructField("MAP", MapType(TimestampType, VarcharType(10))) :: StructField("arrAy", ArrayType(DoubleType)) :: - StructField("anotherArray", ArrayType(StringType)) :: Nil), - 
nullable = true, - builder.build()) :: Nil) + StructField("anotherArray", ArrayType(CharType(9))) :: Nil)) :: Nil) - assert(parse(tableSchemaString) === expectedDataType) + assert(parse(tableSchemaString) === + CharVarcharUtils.replaceCharVarcharWithStringInSchema(expectedDataType)) } // Negative cases diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index ffff00b54f1b..cfb044b428e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -28,7 +28,7 @@ import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.connector.read._ @@ -116,11 +116,12 @@ class InMemoryTable( } } + val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) partitioning.map { case IdentityTransform(ref) => - extractor(ref.fieldNames, schema, row)._1 + extractor(ref.fieldNames, cleanedSchema, row)._1 case YearsTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (days: Int, DateType) => ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) case (micros: Long, TimestampType) => @@ -130,7 +131,7 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case MonthsTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (days: Int, DateType) => ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) case (micros: Long, TimestampType) => @@ -140,7 +141,7 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case DaysTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (days, DateType) => days case (micros: Long, TimestampType) => @@ -149,14 +150,14 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case HoursTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (micros: Long, TimestampType) => ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case BucketTransform(numBuckets, ref) => - val (value, dataType) = extractor(ref.fieldNames, schema, row) + val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) val valueHashCode = if (value == null) 0 else value.hashCode ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala index 7a9a7f52ff8f..da5cfab8be3c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala @@ -28,7 +28,7 @@ class CatalogV2UtilSuite extends SparkFunSuite { val testCatalog = mock(classOf[TableCatalog]) val ident = mock(classOf[Identifier]) val table = mock(classOf[Table]) - when(table.schema()).thenReturn(mock(classOf[StructType])) + when(table.schema()).thenReturn(new StructType().add("i", "int")) when(testCatalog.loadTable(ident)).thenReturn(table) val r = CatalogV2Util.loadRelation(testCatalog, ident) assert(r.isDefined) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c164835c753e..b3e403ffa738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -1181,7 +1181,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = withExpr { Cast(expr, to) } + def cast(to: DataType): Column = withExpr { + Cast(expr, CharVarcharUtils.replaceCharVarcharWithString(to)) + } /** * Casts the column to a different data type, using the canonical string representation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b26bc6441b6c..49b3335bf176 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailureSafeParser} +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, FailureSafeParser} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsCatalogOptions, SupportsRead} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.command.DDLUtils @@ -73,7 +73,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def schema(schema: StructType): DataFrameReader = { - this.userSpecifiedSchema = Option(schema) + this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index f49caf7f04a2..c1c202aced6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.HiveSerDe -import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} +import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} /** * Resolves catalogs from the multi-part identifiers in SQL statements, and convert the statements @@ -51,9 +51,6 @@ class ResolveSessionCatalog( cols.foreach(c => failNullType(c.dataType)) loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => - if (!DDLUtils.isHiveTable(v1Table.v1Table)) { - cols.foreach(c => failCharType(c.dataType)) - } cols.foreach { c => assertTopLevelColumn(c.name, "AlterTableAddColumnsCommand") if (!c.nullable) { @@ -63,7 +60,6 @@ class ResolveSessionCatalog( } AlterTableAddColumnsCommand(tbl.asTableIdentifier, cols.map(convertToStructField)) }.getOrElse { - cols.foreach(c => failCharType(c.dataType)) val changes = cols.map { col => TableChange.addColumn( col.name.toArray, @@ -82,7 +78,6 @@ class ResolveSessionCatalog( case Some(_: V1Table) => throw new AnalysisException("REPLACE COLUMNS is only supported with v2 tables.") case Some(table) => - cols.foreach(c => failCharType(c.dataType)) // REPLACE COLUMNS deletes all the existing columns and adds new columns specified. val deleteChanges = table.schema.fieldNames.map { name => TableChange.deleteColumn(Array(name)) @@ -105,10 +100,6 @@ class ResolveSessionCatalog( a.dataType.foreach(failNullType) loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => - if (!DDLUtils.isHiveTable(v1Table.v1Table)) { - a.dataType.foreach(failCharType) - } - if (a.column.length > 1) { throw new AnalysisException( "ALTER COLUMN with qualified column is only supported with v2 tables.") @@ -134,19 +125,13 @@ class ResolveSessionCatalog( s"Available: ${v1Table.schema.fieldNames.mkString(", ")}") } } - // Add Hive type string to metadata. 
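The net effect of dropping failCharType/assertNoCharTypeInSchema in this file: CHAR/VARCHAR columns are no longer rejected for non-Hive-serde tables. A sketch of DDL that now goes through (table name and provider are illustrative):

```scala
spark.sql("CREATE TABLE t (c CHAR(5), v VARCHAR(3)) USING parquet")
spark.sql("ALTER TABLE t ADD COLUMNS (nick VARCHAR(10))")
```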
- val cleanedDataType = HiveStringType.replaceCharType(dataType) - if (dataType != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, dataType.catalogString) - } val newColumn = StructField( colName, - cleanedDataType, + dataType, nullable = true, builder.build()) AlterTableChangeColumnCommand(tbl.asTableIdentifier, colName, newColumn) }.getOrElse { - a.dataType.foreach(failCharType) val colName = a.column.toArray val typeChange = a.dataType.map { newDataType => TableChange.updateColumnType(colName, newDataType) @@ -271,16 +256,12 @@ class ResolveSessionCatalog( val (storageFormat, provider) = getStorageFormatAndProvider( c.provider, c.options, c.location, c.serde, ctas = false) if (!isV2Provider(provider)) { - if (!DDLUtils.isHiveTable(Some(provider))) { - assertNoCharTypeInSchema(c.tableSchema) - } val tableDesc = buildCatalogTable(tbl.asTableIdentifier, c.tableSchema, c.partitioning, c.bucketSpec, c.properties, provider, c.location, c.comment, storageFormat, c.external) val mode = if (c.ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTable(tableDesc, mode, None) } else { - assertNoCharTypeInSchema(c.tableSchema) CreateV2Table( catalog.asTableCatalog, tbl.asIdentifier, @@ -305,7 +286,6 @@ class ResolveSessionCatalog( val mode = if (c.ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTable(tableDesc, mode, Some(c.asSelect)) } else { - assertNoCharTypeInSchema(c.schema) CreateTableAsSelect( catalog.asTableCatalog, tbl.asIdentifier, @@ -332,7 +312,6 @@ class ResolveSessionCatalog( if (!isV2Provider(provider)) { throw new AnalysisException("REPLACE TABLE is only supported with v2 tables.") } else { - assertNoCharTypeInSchema(c.tableSchema) ReplaceTable( catalog.asTableCatalog, tbl.asIdentifier, @@ -771,17 +750,7 @@ class ResolveSessionCatalog( private def convertToStructField(col: QualifiedColType): StructField = { val builder = new MetadataBuilder col.comment.foreach(builder.putString("comment", _)) - - val cleanedDataType = HiveStringType.replaceCharType(col.dataType) - if (col.dataType != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, col.dataType.catalogString) - } - - StructField( - col.name.head, - cleanedDataType, - nullable = true, - builder.build()) + StructField(col.name.head, col.dataType, nullable = true, builder.build()) } private def isV2Provider(provider: String): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala new file mode 100644 index 000000000000..35bb86f178eb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryComparison, Expression, In, Literal, StringRPad} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.{CharType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * This rule applies char type padding in two places: + * 1. When reading values from column/field of type CHAR(N), right-pad the values to length N. + * 2. When comparing char type column/field with string literal or char type column/field, + * right-pad the shorter one to the longer length. + */ +object ApplyCharTypePadding extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = { + val padded = plan.resolveOperatorsUpWithNewOutput { + case r: LogicalRelation => + val projectList = CharVarcharUtils.charTypePadding(r.output) + if (projectList == r.output) { + r -> Nil + } else { + val cleanedOutput = r.output.map(CharVarcharUtils.cleanAttrMetadata) + val padded = Project(projectList, r.copy(output = cleanedOutput)) + padded -> r.output.zip(padded.output) + } + + case r: DataSourceV2Relation => + val projectList = CharVarcharUtils.charTypePadding(r.output) + if (projectList == r.output) { + r -> Nil + } else { + val cleanedOutput = r.output.map(CharVarcharUtils.cleanAttrMetadata) + val padded = Project(projectList, r.copy(output = cleanedOutput)) + padded -> r.output.zip(padded.output) + } + + case r: HiveTableRelation => + val projectList = CharVarcharUtils.charTypePadding(r.output) + if (projectList == r.output) { + r -> Nil + } else { + val cleanedDataCols = r.dataCols.map(CharVarcharUtils.cleanAttrMetadata) + val cleanedPartCols = r.partitionCols.map(CharVarcharUtils.cleanAttrMetadata) + val padded = Project(projectList, + r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols)) + padded -> r.output.zip(padded.output) + } + } + + padded.resolveOperatorsUp { + case operator if operator.resolved => operator.transformExpressionsUp { + // String literal is treated as char type when it's compared to a char type column. + // We should pad the shorter one to the longer length. + case b @ BinaryComparison(attr: Attribute, lit) if lit.foldable => + padAttrLitCmp(attr, lit).map { newChildren => + b.withNewChildren(newChildren) + }.getOrElse(b) + + case b @ BinaryComparison(lit, attr: Attribute) if lit.foldable => + padAttrLitCmp(attr, lit).map { newChildren => + b.withNewChildren(newChildren.reverse) + }.getOrElse(b) + + case i @ In(attr: Attribute, list) + if attr.dataType == StringType && list.forall(_.foldable) => + CharVarcharUtils.getRawType(attr.metadata).flatMap { + case CharType(length) => + val literalCharLengths = list.map(_.eval().asInstanceOf[UTF8String].numChars()) + val targetLen = (length +: literalCharLengths).max + Some(i.copy( + value = addPadding(attr, length, targetLen), + list = list.zip(literalCharLengths).map { + case (lit, charLength) => addPadding(lit, charLength, targetLen) + })) + case _ => None + }.getOrElse(i) + + // For char type column or inner field comparison, pad the shorter one to the longer length. 
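In user terms, the comparison rewrites performed by this rule look roughly as follows, for a hypothetical table t whose column c is declared CHAR(5) and is therefore read back padded to length 5:

```scala
// c = 'abc'              becomes   c = rpad('abc', 5)       -- literal padded up to the char length
// c = 'abcdefg'          becomes   rpad(c, 7) = 'abcdefg'   -- column padded up to the longer literal
// c IN ('ab', 'abcdefg') pads every operand to the longest length (7 here)
spark.sql("SELECT * FROM t WHERE c = 'abc'").explain(true) // the padded comparison shows up in the plan
```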
+ case b @ BinaryComparison(left: Attribute, right: Attribute) => + b.withNewChildren(CharVarcharUtils.addPaddingInStringComparison(Seq(left, right))) + + case i @ In(attr: Attribute, list) if list.forall(_.isInstanceOf[Attribute]) => + val newChildren = CharVarcharUtils.addPaddingInStringComparison( + attr +: list.map(_.asInstanceOf[Attribute])) + i.copy(value = newChildren.head, list = newChildren.tail) + } + } + } + + private def padAttrLitCmp(attr: Attribute, lit: Expression): Option[Seq[Expression]] = { + if (attr.dataType == StringType) { + CharVarcharUtils.getRawType(attr.metadata).flatMap { + case CharType(length) => + val str = lit.eval().asInstanceOf[UTF8String] + val stringLitLen = str.numChars() + if (length < stringLitLen) { + Some(Seq(StringRPad(attr, Literal(stringLitLen)), lit)) + } else if (length > stringLitLen) { + Some(Seq(attr, StringRPad(lit, Literal(length)))) + } else { + None + } + case _ => None + } + } else { + None + } + } + + private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = { + if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 33a3486bf6f6..8c61c8cd4f52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} import org.apache.spark.sql.sources.BaseRelation /** @@ -69,9 +69,17 @@ case class LogicalRelation( } object LogicalRelation { - def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation = - LogicalRelation(relation, relation.schema.toAttributes, None, isStreaming) + def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation = { + // The v1 source may return schema containing char/varchar type. We replace char/varchar + // with "annotated" string type here as the query engine doesn't support char/varchar yet. + val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema) + LogicalRelation(relation, schema.toAttributes, None, isStreaming) + } - def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = - LogicalRelation(relation, relation.schema.toAttributes, Some(table), false) + def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = { + // The v1 source may return schema containing char/varchar type. We replace char/varchar + // with "annotated" string type here as the query engine doesn't support char/varchar yet. 
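Reviewer note: with the rules above, a resolved relation never exposes `CharType`/`VarcharType` in its output; the declared type survives only in the column metadata, and the read-side padding is injected as a `Project`. A small sketch of the intended user-visible behavior, matching the new CharVarcharTestSuite (assumes an active Spark 3.1+ session named `spark` and a v1 parquet table):

```scala
import org.apache.spark.sql.types.StringType

spark.sql("CREATE TABLE char_demo(c CHAR(5)) USING parquet")
spark.sql("INSERT INTO char_demo VALUES ('ab')")

val df = spark.table("char_demo")
// The analyzed schema exposes plain StringType; CHAR(5) survives only in the field metadata.
assert(df.schema.head.dataType == StringType)
// ApplyCharTypePadding right-pads values of CHAR(5) columns to length 5 on read.
assert(df.head().getString(0) == "ab   ")
```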
+ val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema) + LogicalRelation(relation, schema.toAttributes, Some(table), false) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 78f31fb80ecf..5dd0d2bd7483 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} @@ -761,17 +761,10 @@ object JdbcUtils extends Logging { schema: StructType, caseSensitive: Boolean, createTableColumnTypes: String): Map[String, String] = { - def typeName(f: StructField): String = { - // char/varchar gets translated to string type. Real data type specified by the user - // is available in the field metadata as HIVE_TYPE_STRING - if (f.metadata.contains(HIVE_TYPE_STRING)) { - f.metadata.getString(HIVE_TYPE_STRING) - } else { - f.dataType.catalogString - } - } - - val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes) + val parsedSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes) + val userSchema = StructType(parsedSchema.map { field => + field.copy(dataType = CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType)) + }) val nameEquality = if (caseSensitive) { org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution } else { @@ -791,7 +784,7 @@ object JdbcUtils extends Logging { } } - val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap + val userSchemaMap = userSchema.fields.map(f => f.name -> f.dataType.catalogString).toMap if (caseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index ce8edce6f08d..2208e930f6b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf @@ -110,7 +111,8 @@ object PushDownUtils extends PredicateHelper { schema: StructType, relation: 
DataSourceV2Relation): Seq[AttributeReference] = { val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - schema.toAttributes.map { + val cleaned = CharVarcharUtils.replaceCharVarcharWithString(schema).asInstanceOf[StructType] + cleaned.toAttributes.map { // we have to keep the attribute id during transformation a => a.withExprId(nameToAttr(a.name).exprId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 538a5408723b..a89a5de3b7e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -189,6 +189,7 @@ abstract class BaseSessionStateBuilder( PreprocessTableCreation(session) +: PreprocessTableInsertion +: DataSourceAnalysis +: + ApplyCharTypePadding +: customPostHocResolutionRules override val extendedCheckRules: Seq[LogicalPlan => Unit] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9bc4acd49a98..4e755682242d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{SupportsRead, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.command.DDLUtils @@ -64,7 +64,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * @since 2.0.0 */ def schema(schema: StructType): DataStreamReader = { - this.userSpecifiedSchema = Option(schema) + this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)) this } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala new file mode 100644 index 000000000000..abb13270d20e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -0,0 +1,505 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
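Reviewer note: `DataStreamReader.schema` and the v1/v2 read paths now downgrade a user-specified CHAR/VARCHAR schema to `StringType`, consistent with the "annotated string" comments above, since the query engine doesn't support char/varchar natively yet. A hedged sketch of how that looks through the public reader API (assumes an active session `spark`):

```scala
import org.apache.spark.sql.types.{CharType, StringType, StructType}
import spark.implicits._

// CHAR(5) in a user-specified read schema is relaxed to plain StringType.
val ds = spark.range(3).map(_.toString)
val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds)
val df2 = spark.read.schema("id char(5)").csv(ds)
assert(df1.schema.map(_.dataType) == Seq(StringType))
assert(df2.schema.map(_.dataType) == Seq(StringType))
```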
+ */ + +package org.apache.spark.sql + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.{InMemoryPartitionTableCatalog, SchemaRequiredDataSource} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.SimpleInsertSource +import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} +import org.apache.spark.sql.types.{ArrayType, CharType, DataType, MapType, StringType, StructField, StructType} + +// The base trait for char/varchar tests that need to be run with different table implementations. +trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { + + def format: String + + def checkColType(f: StructField, dt: DataType): Unit = { + assert(f.dataType == CharVarcharUtils.replaceCharVarcharWithString(dt)) + assert(CharVarcharUtils.getRawType(f.metadata) == Some(dt)) + } + + test("char type values should be padded: top-level columns") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c CHAR(5)) USING $format") + sql("INSERT INTO t VALUES ('1', 'a')") + checkAnswer(spark.table("t"), Row("1", "a" + " " * 4)) + checkColType(spark.table("t").schema(1), CharType(5)) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + } + } + + test("char type values should be padded: partitioned columns") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c CHAR(5)) USING $format PARTITIONED BY (c)") + sql("INSERT INTO t VALUES ('1', 'a')") + checkAnswer(spark.table("t"), Row("1", "a" + " " * 4)) + checkColType(spark.table("t").schema(1), CharType(5)) + + sql("ALTER TABLE t DROP PARTITION(c='a')") + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + } + } + + test("char type values should be padded: nested in struct") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c STRUCT) USING $format") + sql("INSERT INTO t VALUES ('1', struct('a'))") + checkAnswer(spark.table("t"), Row("1", Row("a" + " " * 4))) + checkColType(spark.table("t").schema(1), new StructType().add("c", CharType(5))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', struct(null))") + checkAnswer(spark.table("t"), Row("1", Row(null))) + } + } + + test("char type values should be padded: nested in array") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c ARRAY) USING $format") + sql("INSERT INTO t VALUES ('1', array('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Seq("a" + " " * 4, "ab" + " " * 3))) + checkColType(spark.table("t").schema(1), ArrayType(CharType(5))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', array(null))") + checkAnswer(spark.table("t"), Row("1", Seq(null))) + } + } + + test("char type values should be padded: nested in map key") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c MAP) USING $format") + sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Map(("a" + " " * 4, "ab")))) + checkColType(spark.table("t").schema(1), MapType(CharType(5), StringType)) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + } + } + + test("char type values should be padded: nested in map 
value") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c MAP) USING $format") + sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Map(("a", "ab" + " " * 3)))) + checkColType(spark.table("t").schema(1), MapType(StringType, CharType(5))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', map('a', null))") + checkAnswer(spark.table("t"), Row("1", Map("a" -> null))) + } + } + + test("char type values should be padded: nested in both map key and value") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c MAP) USING $format") + sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Map(("a" + " " * 4, "ab" + " " * 8)))) + checkColType(spark.table("t").schema(1), MapType(CharType(5), CharType(10))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + } + } + + test("char type values should be padded: nested in struct of array") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c STRUCT>) USING $format") + sql("INSERT INTO t VALUES ('1', struct(array('a', 'ab')))") + checkAnswer(spark.table("t"), Row("1", Row(Seq("a" + " " * 4, "ab" + " " * 3)))) + checkColType(spark.table("t").schema(1), + new StructType().add("c", ArrayType(CharType(5)))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', struct(null))") + checkAnswer(spark.table("t"), Row("1", Row(null))) + sql("INSERT OVERWRITE t VALUES ('1', struct(array(null)))") + checkAnswer(spark.table("t"), Row("1", Row(Seq(null)))) + } + } + + test("char type values should be padded: nested in array of struct") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c ARRAY>) USING $format") + sql("INSERT INTO t VALUES ('1', array(struct('a'), struct('ab')))") + checkAnswer(spark.table("t"), Row("1", Seq(Row("a" + " " * 4), Row("ab" + " " * 3)))) + checkColType(spark.table("t").schema(1), + ArrayType(new StructType().add("c", CharType(5)))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', array(null))") + checkAnswer(spark.table("t"), Row("1", Seq(null))) + sql("INSERT OVERWRITE t VALUES ('1', array(struct(null)))") + checkAnswer(spark.table("t"), Row("1", Seq(Row(null)))) + } + } + + test("char type values should be padded: nested in array of array") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c ARRAY>) USING $format") + sql("INSERT INTO t VALUES ('1', array(array('a', 'ab')))") + checkAnswer(spark.table("t"), Row("1", Seq(Seq("a" + " " * 4, "ab" + " " * 3)))) + checkColType(spark.table("t").schema(1), ArrayType(ArrayType(CharType(5)))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', array(null))") + checkAnswer(spark.table("t"), Row("1", Seq(null))) + sql("INSERT OVERWRITE t VALUES ('1', array(array(null)))") + checkAnswer(spark.table("t"), Row("1", Seq(Seq(null)))) + } + } + + private def testTableWrite(f: String => Unit): Unit = { + withTable("t") { f("char") } + withTable("t") { f("varchar") } + } + + test("length check for input string values: top-level columns") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c $typeName(5)) USING $format") + sql("INSERT INTO t VALUES (null)") + checkAnswer(spark.table("t"), Row(null)) + val e = 
intercept[SparkException](sql("INSERT INTO t VALUES ('123456')")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: partitioned columns") { + // DS V2 doesn't support partitioned table. + if (!conf.contains(SQLConf.DEFAULT_CATALOG.key)) { + testTableWrite { typeName => + sql(s"CREATE TABLE t(i INT, c $typeName(5)) USING $format PARTITIONED BY (c)") + sql("INSERT INTO t VALUES (1, null)") + checkAnswer(spark.table("t"), Row(1, null)) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (1, '123456')")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + } + + test("length check for input string values: nested in struct") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c STRUCT) USING $format") + sql("INSERT INTO t SELECT struct(null)") + checkAnswer(spark.table("t"), Row(Row(null))) + val e = intercept[SparkException](sql("INSERT INTO t SELECT struct('123456')")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in array") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c ARRAY<$typeName(5)>) USING $format") + sql("INSERT INTO t VALUES (array(null))") + checkAnswer(spark.table("t"), Row(Seq(null))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (array('a', '123456'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in map key") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c MAP<$typeName(5), STRING>) USING $format") + val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in map value") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c MAP) USING $format") + sql("INSERT INTO t VALUES (map('a', null))") + checkAnswer(spark.table("t"), Row(Map("a" -> null))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in both map key and value") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c MAP<$typeName(5), $typeName(5)>) USING $format") + val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))")) + assert(e1.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))")) + assert(e2.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in struct of array") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c STRUCT>) USING $format") + sql("INSERT INTO t SELECT struct(array(null))") + checkAnswer(spark.table("t"), Row(Row(Seq(null)))) + val e = intercept[SparkException](sql("INSERT INTO t SELECT struct(array('123456'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length 
limitation: 5")) + } + } + + test("length check for input string values: nested in array of struct") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c ARRAY>) USING $format") + sql("INSERT INTO t VALUES (array(struct(null)))") + checkAnswer(spark.table("t"), Row(Seq(Row(null)))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(struct('123456')))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in array of array") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c ARRAY>) USING $format") + sql("INSERT INTO t VALUES (array(array(null)))") + checkAnswer(spark.table("t"), Row(Seq(Seq(null)))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(array('123456')))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: with trailing spaces") { + withTable("t") { + sql(s"CREATE TABLE t(c1 CHAR(5), c2 VARCHAR(5)) USING $format") + sql("INSERT INTO t VALUES ('12 ', '12 ')") + sql("INSERT INTO t VALUES ('1234 ', '1234 ')") + checkAnswer(spark.table("t"), Seq( + Row("12" + " " * 3, "12 "), + Row("1234 ", "1234 "))) + } + } + + test("length check for input string values: with implicit cast") { + withTable("t") { + sql(s"CREATE TABLE t(c1 CHAR(5), c2 VARCHAR(5)) USING $format") + sql("INSERT INTO t VALUES (1234, 1234)") + checkAnswer(spark.table("t"), Row("1234 ", "1234")) + val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (123456, 1)")) + assert(e1.getCause.getMessage.contains( + "input string '123456' exceeds char type length limitation: 5")) + val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (1, 123456)")) + assert(e2.getCause.getMessage.contains( + "input string '123456' exceeds varchar type length limitation: 5")) + } + } + + private def testConditions(df: DataFrame, conditions: Seq[(String, Boolean)]): Unit = { + checkAnswer(df.selectExpr(conditions.map(_._1): _*), Row.fromSeq(conditions.map(_._2))) + } + + test("char type comparison: top-level columns") { + withTable("t") { + sql(s"CREATE TABLE t(c1 CHAR(2), c2 CHAR(5)) USING $format") + sql("INSERT INTO t VALUES ('a', 'a')") + testConditions(spark.table("t"), Seq( + ("c1 = 'a'", true), + ("'a' = c1", true), + ("c1 = 'a '", true), + ("c1 > 'a'", false), + ("c1 IN ('a', 'b')", true), + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: partitioned columns") { + withTable("t") { + sql(s"CREATE TABLE t(i INT, c1 CHAR(2), c2 CHAR(5)) USING $format PARTITIONED BY (c1, c2)") + sql("INSERT INTO t VALUES (1, 'a', 'a')") + testConditions(spark.table("t"), Seq( + ("c1 = 'a'", true), + ("'a' = c1", true), + ("c1 = 'a '", true), + ("c1 > 'a'", false), + ("c1 IN ('a', 'b')", true), + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: join") { + withTable("t1", "t2") { + sql(s"CREATE TABLE t1(c CHAR(2)) USING $format") + sql(s"CREATE TABLE t2(c CHAR(5)) USING $format") + sql("INSERT INTO t1 VALUES ('a')") + sql("INSERT INTO t2 VALUES ('a')") + checkAnswer(sql("SELECT t1.c FROM t1 JOIN t2 ON t1.c = t2.c"), Row("a ")) + } + } + + test("char type comparison: nested in struct") { + withTable("t") { + sql(s"CREATE TABLE t(c1 STRUCT, c2 STRUCT) USING $format") + sql("INSERT INTO t VALUES (struct('a'), struct('a'))") + 
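Reviewer note: the comparison tests in this suite all reduce to the same rule: before comparing, the shorter CHAR side (or a string literal) is right-padded to the longer declared length. A condensed sketch of that behavior against a real table (assumes an active session `spark`):

```scala
spark.sql("CREATE TABLE cmp_demo(c1 CHAR(2), c2 CHAR(5)) USING parquet")
spark.sql("INSERT INTO cmp_demo VALUES ('a', 'a')")

// 'a ' (CHAR(2)) equals 'a    ' (CHAR(5)) once the shorter side is padded to length 5.
assert(spark.sql("SELECT c1 = c2 FROM cmp_demo").head().getBoolean(0))
// String literals compared with a CHAR column are padded too, so 'a' and 'a ' both match.
assert(spark.sql("SELECT c1 = 'a' AND c1 = 'a ' FROM cmp_demo").head().getBoolean(0))
// Ordering sees the padded values as well, so c1 is not strictly less than c2 here.
assert(!spark.sql("SELECT c1 < c2 FROM cmp_demo").head().getBoolean(0))
```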
testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in array") { + withTable("t") { + sql(s"CREATE TABLE t(c1 ARRAY, c2 ARRAY) USING $format") + sql("INSERT INTO t VALUES (array('a', 'b'), array('a', 'b'))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in struct of array") { + withTable("t") { + sql("CREATE TABLE t(c1 STRUCT>, c2 STRUCT>) " + + s"USING $format") + sql("INSERT INTO t VALUES (struct(array('a', 'b')), struct(array('a', 'b')))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in array of struct") { + withTable("t") { + sql("CREATE TABLE t(c1 ARRAY>, c2 ARRAY>) " + + s"USING $format") + sql("INSERT INTO t VALUES (array(struct('a')), array(struct('a')))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in array of array") { + withTable("t") { + sql("CREATE TABLE t(c1 ARRAY>, c2 ARRAY>) " + + s"USING $format") + sql("INSERT INTO t VALUES (array(array('a')), array(array('a')))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } +} + +// Some basic char/varchar tests which doesn't rely on table implementation. +class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("user-specified schema in cast") { + def assertNoCharType(df: DataFrame): Unit = { + checkAnswer(df, Row("0")) + assert(df.schema.map(_.dataType) == Seq(StringType)) + } + + assertNoCharType(spark.range(1).select($"id".cast("char(5)"))) + assertNoCharType(spark.range(1).select($"id".cast(CharType(5)))) + assertNoCharType(spark.range(1).selectExpr("CAST(id AS CHAR(5))")) + assertNoCharType(sql("SELECT CAST(id AS CHAR(5)) FROM range(1)")) + } + + test("user-specified schema in functions") { + val df = sql("""SELECT from_json('{"a": "str"}', 'a CHAR(5)')""") + checkAnswer(df, Row(Row("str"))) + val schema = df.schema.head.dataType.asInstanceOf[StructType] + assert(schema.map(_.dataType) == Seq(StringType)) + } + + test("user-specified schema in DataFrameReader: file source from Dataset") { + val ds = spark.range(10).map(_.toString) + val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds) + assert(df1.schema.map(_.dataType) == Seq(StringType)) + val df2 = spark.read.schema("id char(5)").csv(ds) + assert(df2.schema.map(_.dataType) == Seq(StringType)) + } + + test("user-specified schema in DataFrameReader: DSV1") { + def checkSchema(df: DataFrame): Unit = { + val relations = df.queryExecution.analyzed.collect { + case l: LogicalRelation => l.relation + } + assert(relations.length == 1) + assert(relations.head.schema.map(_.dataType) == Seq(StringType)) + } + + checkSchema(spark.read.schema(new StructType().add("id", CharType(5))) + .format(classOf[SimpleInsertSource].getName).load()) + checkSchema(spark.read.schema("id char(5)") + .format(classOf[SimpleInsertSource].getName).load()) + } + + test("user-specified schema in DataFrameReader: DSV2") { + def checkSchema(df: DataFrame): Unit = { + val tables = df.queryExecution.analyzed.collect { + case d: DataSourceV2Relation => d.table + } + assert(tables.length == 1) + assert(tables.head.schema.map(_.dataType) == 
Seq(StringType)) + } + + checkSchema(spark.read.schema(new StructType().add("id", CharType(5))) + .format(classOf[SchemaRequiredDataSource].getName).load()) + checkSchema(spark.read.schema("id char(5)") + .format(classOf[SchemaRequiredDataSource].getName).load()) + } +} + +class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSparkSession { + override def format: String = "parquet" + override protected def sparkConf: SparkConf = { + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "parquet") + } +} + +class DSV2CharVarcharTestSuite extends CharVarcharTestSuite + with SharedSparkSession { + override def format: String = "foo" + protected override def sparkConf = { + super.sparkConf + .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName) + .set(SQLConf.DEFAULT_CATALOG.key, "testcat") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 92c114e116d0..f9bc1f3c603c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.sources.SimpleScanSource -import org.apache.spark.sql.types.{CharType, DoubleType, HIVE_TYPE_STRING, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType} +import org.apache.spark.sql.types.{CharType, DoubleType, IntegerType, LongType, StringType, StructField, StructType} class PlanResolutionSuite extends AnalysisTest { import CatalystSqlParser._ @@ -1090,9 +1090,7 @@ class PlanResolutionSuite extends AnalysisTest { } val sql = s"ALTER TABLE v1HiveTable ALTER COLUMN i TYPE char(1)" - val builder = new MetadataBuilder - builder.putString(HIVE_TYPE_STRING, CharType(1).catalogString) - val newColumnWithCleanedType = StructField("i", StringType, true, builder.build()) + val newColumnWithCleanedType = StructField("i", CharType(1), true) val expected = AlterTableChangeColumnCommand( TableIdentifier("v1HiveTable", Some("default")), "i", newColumnWithCleanedType) val parsed = parseAndResolve(sql) @@ -1533,44 +1531,6 @@ class PlanResolutionSuite extends AnalysisTest { } } - test("SPARK-31147: forbid CHAR type in non-Hive tables") { - def checkFailure(t: String, provider: String): Unit = { - val types = Seq( - "CHAR(2)", - "ARRAY", - "MAP", - "MAP", - "STRUCT") - types.foreach { tpe => - intercept[AnalysisException] { - parseAndResolve(s"CREATE TABLE $t(col $tpe) USING $provider") - } - intercept[AnalysisException] { - parseAndResolve(s"REPLACE TABLE $t(col $tpe) USING $provider") - } - intercept[AnalysisException] { - parseAndResolve(s"CREATE OR REPLACE TABLE $t(col $tpe) USING $provider") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t ADD COLUMN col $tpe") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t ADD COLUMN col $tpe") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t ALTER COLUMN col TYPE $tpe") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t REPLACE COLUMNS (col $tpe)") - } - } - } - - checkFailure("v1Table", v1Format) - checkFailure("v2Table", v2Format) - checkFailure("testcat.tab", "foo") 
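Reviewer note: the removed SPARK-31147 test reflects that CHAR is no longer rejected in non-Hive DDL, while BasicCharVarcharTestSuite above pins down the other side of the contract: outside of a table schema, a CHAR(n) cast or function schema degrades to a plain string with no padding and no length check. A quick sketch (assumes an active session `spark`):

```scala
import org.apache.spark.sql.types.StringType

// CAST(... AS CHAR(5)) produces an ordinary string column: no padding, no length check.
val casted = spark.sql("SELECT CAST(id AS CHAR(5)) AS c FROM range(1)")
assert(casted.schema.head.dataType == StringType)
assert(casted.head().getString(0) == "0")   // not padded to "0    "
```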
- } - private def compareNormalized(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { /** * Normalizes plans: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 9a95bf770772..ca3e71466581 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -127,7 +128,7 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", - s"char_$i", + s"char_$i".padTo(18, ' '), Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -206,10 +207,6 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { (2 to 10).map(i => Row(i, i - 1)).toSeq) test("Schema and all fields") { - def hiveMetadata(dt: String): Metadata = { - new MetadataBuilder().putString(HIVE_TYPE_STRING, dt).build() - } - val expectedSchema = StructType( StructField("string$%Field", StringType, true) :: StructField("binaryField", BinaryType, true) :: @@ -224,8 +221,8 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { StructField("decimalField2", DecimalType(9, 2), true) :: StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: - StructField("varcharField", StringType, true, hiveMetadata("varchar(12)")) :: - StructField("charField", StringType, true, hiveMetadata("char(18)")) :: + StructField("varcharField", VarcharType(12), true) :: + StructField("charField", CharType(18), true) :: StructField("arrayFieldSimple", ArrayType(IntegerType), true) :: StructField("arrayFieldComplex", ArrayType( @@ -248,7 +245,8 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { Nil ) - assert(expectedSchema == spark.table("tableWithSchema").schema) + assert(CharVarcharUtils.replaceCharVarcharWithStringInSchema(expectedSchema) == + spark.table("tableWithSchema").schema) withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { checkAnswer( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index b30492802495..da37b6168895 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -90,6 +90,7 @@ class HiveSessionStateBuilder( PreprocessTableCreation(session) +: PreprocessTableInsertion +: DataSourceAnalysis +: + ApplyCharTypePadding +: HiveAnalysis +: customPostHocResolutionRules diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index b2f0867114ba..bada131c8ba6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -978,19 +978,14 @@ private[hive] class HiveClientImpl( private[hive] object HiveClientImpl extends Logging { /** Converts the 
native StructField to Hive's FieldSchema. */ def toHiveColumn(c: StructField): FieldSchema = { - val typeString = if (c.metadata.contains(HIVE_TYPE_STRING)) { - c.metadata.getString(HIVE_TYPE_STRING) - } else { - // replace NullType to HiveVoidType since Hive parse void not null. - HiveVoidType.replaceVoidType(c.dataType).catalogString - } + val typeString = HiveVoidType.replaceVoidType(c.dataType).catalogString new FieldSchema(c.name, typeString, c.getComment().orNull) } /** Get the Spark SQL native DataType from Hive's FieldSchema. */ private def getSparkSQLDataType(hc: FieldSchema): DataType = { try { - CatalystSqlParser.parseDataType(hc.getType) + CatalystSqlParser.parseRawDataType(hc.getType) } catch { case e: ParseException => throw new SparkException( @@ -1001,18 +996,10 @@ private[hive] object HiveClientImpl extends Logging { /** Builds the native StructField from Hive's FieldSchema. */ def fromHiveColumn(hc: FieldSchema): StructField = { val columnType = getSparkSQLDataType(hc) - val replacedVoidType = HiveVoidType.replaceVoidType(columnType) - val metadata = if (hc.getType != replacedVoidType.catalogString) { - new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build() - } else { - Metadata.empty - } - val field = StructField( name = hc.getName, dataType = columnType, - nullable = true, - metadata = metadata) + nullable = true) Option(hc.getComment).map(field.withComment).getOrElse(field) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/HiveCharVarcharTestSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/HiveCharVarcharTestSuite.scala new file mode 100644 index 000000000000..55d305fda4f9 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/HiveCharVarcharTestSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveCharVarcharTestSuite extends CharVarcharTestSuite with TestHiveSingleton { + + // The default Hive serde doesn't support nested null values. 
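Reviewer note: with CHAR/VARCHAR kept as first-class Catalyst types, `toHiveColumn` can rely on `catalogString` alone and `fromHiveColumn` can parse the raw Hive type string back, so the old HIVE_TYPE_STRING metadata round trip disappears. A tiny sketch of the property this relies on (these types are public in Spark 3.1+; the assertions reflect my reading of `catalogString`, not code from the patch):

```scala
import org.apache.spark.sql.types.{ArrayType, CharType, VarcharType}

// catalogString now carries the declared length, matching what Hive stores in the metastore.
assert(CharType(10).catalogString == "char(10)")
assert(VarcharType(12).catalogString == "varchar(12)")
assert(ArrayType(CharType(10)).catalogString == "array<char(10)>")
```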
+ override def format: String = "hive OPTIONS(fileFormat='parquet')" + + private var originalPartitionMode = "" + + override protected def beforeAll(): Unit = { + super.beforeAll() + originalPartitionMode = spark.conf.get("hive.exec.dynamic.partition.mode", "") + spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict") + } + + override protected def afterAll(): Unit = { + if (originalPartitionMode == "") { + spark.conf.unset("hive.exec.dynamic.partition.mode") + } else { + spark.conf.set("hive.exec.dynamic.partition.mode", originalPartitionMode) + } + super.afterAll() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 8f71ba3337aa..1a6f6843d391 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -113,24 +113,19 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { .add("c9", "date") .add("c10", "timestamp") .add("c11", "string") - .add("c12", "string", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "char(10)").build()) - .add("c13", "string", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "varchar(10)").build()) + .add("c12", CharType(10), true) + .add("c13", VarcharType(10), true) .add("c14", "binary") .add("c15", "decimal") .add("c16", "decimal(10)") .add("c17", "decimal(10,2)") .add("c18", "array") .add("c19", "array") - .add("c20", "array", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "array").build()) + .add("c20", ArrayType(CharType(10)), true) .add("c21", "map") - .add("c22", "map", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "map").build()) + .add("c22", MapType(IntegerType, CharType(10)), true) .add("c23", "struct") - .add("c24", "struct", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "struct").build()) + .add("c24", new StructType().add("c", VarcharType(10)).add("d", "int"), true) assert(schema == expectedSchema) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index b8b1da4cb9db..2dfb8bb55259 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -2251,8 +2251,8 @@ class HiveDDLSuite ) sql("ALTER TABLE tab ADD COLUMNS (c5 char(10))") - assert(spark.table("tab").schema.find(_.name == "c5") - .get.metadata.getString("HIVE_TYPE_STRING") == "char(10)") + assert(spark.sharedState.externalCatalog.getTable("default", "tab") + .schema.find(_.name == "c5").get.dataType == CharType(10)) } } }
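Reviewer note: taken together, the Hive-side changes mean a CHAR column declared through DDL keeps its raw type in the catalog, is padded on read, and is length-checked on write, for Hive serde tables as well as data source tables. The sketch below mirrors the updated HiveDDLSuite assertion (assumes a Hive-enabled session `spark`; `sharedState.externalCatalog` is internal API and is used here only because the test suite does the same, from inside Spark's own packages):

```scala
import org.apache.spark.SparkException
import org.apache.spark.sql.types.{CharType, StringType}

spark.sql("CREATE TABLE hive_char_demo(c1 INT) USING hive")
spark.sql("ALTER TABLE hive_char_demo ADD COLUMNS (c5 CHAR(10))")

// The external catalog keeps the raw CHAR(10) type for the new column...
val catalogType = spark.sharedState.externalCatalog
  .getTable("default", "hive_char_demo").schema.find(_.name == "c5").get.dataType
assert(catalogType == CharType(10))

// ...while the analyzed query schema exposes StringType, as for data source tables.
assert(spark.table("hive_char_demo").schema.find(_.name == "c5").get.dataType == StringType)

// Writes longer than the declared length fail with the shared error message.
val overflowed =
  try {
    spark.sql("INSERT INTO hive_char_demo VALUES (1, '12345678901')")
    false
  } catch {
    case e: SparkException =>
      e.getCause.getMessage.contains("exceeds char type length limitation")
  }
assert(overflowed)
```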