diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala index f4502c924572..447769be3000 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala @@ -53,25 +53,53 @@ case class ResolveDefaultColumns( // This field stores the enclosing INSERT INTO command, once we find one. var enclosingInsert: Option[InsertIntoStatement] = None + // This field stores the schema of the target table of the above command. + var insertTableSchemaWithoutPartitionColumns: Option[StructType] = None - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( - (_ => SQLConf.get.enableDefaultColumns), ruleId) { - case i@InsertIntoStatement(_, _, _, _, _, _) - if i.query.collectFirst { case u: UnresolvedInlineTable => u }.isDefined => - enclosingInsert = Some(i) - i + override def apply(plan: LogicalPlan): LogicalPlan = { + // Initialize by clearing our reference to the enclosing INSERT INTO command. + enclosingInsert = None + insertTableSchemaWithoutPartitionColumns = None + // Traverse the logical query plan in preorder (top-down). + plan.resolveOperatorsWithPruning( + (_ => SQLConf.get.enableDefaultColumns), ruleId) { + case i@InsertIntoStatement(_, _, _, _, _, _) + if i.query.collectFirst { case u: UnresolvedInlineTable + if u.rows.nonEmpty && u.rows.forall(_.size == u.rows(0).size) => u + }.isDefined => + enclosingInsert = Some(i) + insertTableSchemaWithoutPartitionColumns = getInsertTableSchemaWithoutPartitionColumns + val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i) + regenerated + + case table: UnresolvedInlineTable + if enclosingInsert.isDefined => + val expanded: UnresolvedInlineTable = addMissingDefaultColumnValues(table).getOrElse(table) + val replaced: LogicalPlan = + replaceExplicitDefaultColumnValues(analyzer, expanded).getOrElse(table) + replaced + + case i@InsertIntoStatement(_, _, _, project: Project, _, _) => + enclosingInsert = Some(i) + insertTableSchemaWithoutPartitionColumns = getInsertTableSchemaWithoutPartitionColumns + val expanded: Project = addMissingDefaultColumnValues(project).getOrElse(project) + val replaced: Option[LogicalPlan] = replaceExplicitDefaultColumnValues(analyzer, expanded) + val updated: InsertIntoStatement = + if (replaced.isDefined) i.copy(query = replaced.get) else i + val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(updated) + regenerated + } + } - case table: UnresolvedInlineTable - if enclosingInsert.isDefined && - table.rows.nonEmpty && table.rows.forall(_.size == table.rows(0).size) => - val expanded: UnresolvedInlineTable = addMissingDefaultColumnValues(table).getOrElse(table) - replaceExplicitDefaultColumnValues(analyzer, expanded).getOrElse(table) - - case i@InsertIntoStatement(_, _, _, project: Project, _, _) => - enclosingInsert = Some(i) - val expanded: Project = addMissingDefaultColumnValues(project).getOrElse(project) - val replaced: Option[LogicalPlan] = replaceExplicitDefaultColumnValues(analyzer, expanded) - if (replaced.isDefined) i.copy(query = replaced.get) else i + // Helper method to regenerate user-specified columns of an InsertIntoStatement based on the names + // in the insertTableSchemaWithoutPartitionColumns field of this class. 
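+  // For example, for INSERT INTO t (s, i) VALUES (DEFAULT, TRUE), the stored schema lists the
+  // user-specified columns first ('s', then 'i'), followed by any remaining columns that can be
+  // filled with default values, so the column list is simply regenerated from those field names.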
+ private def regenerateUserSpecifiedCols(i: InsertIntoStatement): InsertIntoStatement = { + if (i.userSpecifiedCols.nonEmpty && insertTableSchemaWithoutPartitionColumns.isDefined) { + i.copy( + userSpecifiedCols = insertTableSchemaWithoutPartitionColumns.get.fields.map(_.name)) + } else { + i + } } // Helper method to check if an expression is an explicit DEFAULT column reference. @@ -87,7 +115,7 @@ case class ResolveDefaultColumns( table: UnresolvedInlineTable): Option[UnresolvedInlineTable] = { assert(enclosingInsert.isDefined) val numQueryOutputs: Int = table.rows(0).size - val schema = getInsertTableSchemaWithoutPartitionColumns.getOrElse(return None) + val schema = insertTableSchemaWithoutPartitionColumns.getOrElse(return None) val newDefaultExpressions: Seq[Expression] = getDefaultExpressions(numQueryOutputs, schema) val newNames: Seq[String] = schema.fields.drop(numQueryOutputs).map { _.name } if (newDefaultExpressions.nonEmpty) { @@ -104,7 +132,7 @@ case class ResolveDefaultColumns( */ private def addMissingDefaultColumnValues(project: Project): Option[Project] = { val numQueryOutputs: Int = project.projectList.size - val schema = getInsertTableSchemaWithoutPartitionColumns.getOrElse(return None) + val schema = insertTableSchemaWithoutPartitionColumns.getOrElse(return None) val newDefaultExpressions: Seq[Expression] = getDefaultExpressions(numQueryOutputs, schema) val newAliases: Seq[NamedExpression] = newDefaultExpressions.zip(schema.fields).map { @@ -122,16 +150,21 @@ case class ResolveDefaultColumns( */ private def getDefaultExpressions(numQueryOutputs: Int, schema: StructType): Seq[Expression] = { val remainingFields: Seq[StructField] = schema.fields.drop(numQueryOutputs) - val numDefaultExpressionsToAdd: Int = { - if (SQLConf.get.useNullsForMissingDefaultColumnValues) { - remainingFields.size - } else { - remainingFields.takeWhile(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)).size - } - } + val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME)) } + /** + * This is a helper for the getDefaultExpressions methods above. + */ + private def getStructFieldsForDefaultExpressions(fields: Seq[StructField]): Seq[StructField] = { + if (SQLConf.get.useNullsForMissingDefaultColumnValues) { + fields + } else { + fields.takeWhile(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) + } + } + /** * Replaces unresolved DEFAULT column references with corresponding values in a logical plan. */ @@ -139,7 +172,7 @@ case class ResolveDefaultColumns( analyzer: Analyzer, input: LogicalPlan): Option[LogicalPlan] = { assert(enclosingInsert.isDefined) - val schema = getInsertTableSchemaWithoutPartitionColumns.getOrElse(return None) + val schema = insertTableSchemaWithoutPartitionColumns.getOrElse(return None) val columnNames: Seq[String] = schema.fields.map { _.name } val defaultExpressions: Seq[Expression] = schema.fields.map { case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => @@ -193,7 +226,7 @@ case class ResolveDefaultColumns( } /** - * Replaces unresolved DEFAULT column references with corresponding values in an projection. + * Replaces unresolved DEFAULT column references with corresponding values in a projection. 
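+   * For example, in INSERT INTO t SELECT false, DEFAULT, the DEFAULT reference in the second
+   * position of the projection is replaced with the default expression declared for the second
+   * column of the target table t.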
*/ private def replaceExplicitDefaultColumnValues( defaultExpressions: Seq[Expression], @@ -230,21 +263,48 @@ case class ResolveDefaultColumns( */ private def getInsertTableSchemaWithoutPartitionColumns: Option[StructType] = { assert(enclosingInsert.isDefined) - try { - val tableName = enclosingInsert.get.table match { - case r: UnresolvedRelation => TableIdentifier(r.name) - case r: UnresolvedCatalogRelation => r.tableMeta.identifier - case _ => return None - } - val lookup = catalog.lookupRelation(tableName) - lookup match { - case SubqueryAlias(_, r: UnresolvedCatalogRelation) => - Some(StructType(r.tableMeta.schema.fields.dropRight( - enclosingInsert.get.partitionSpec.size))) - case _ => None - } + val tableName = enclosingInsert.get.table match { + case r: UnresolvedRelation => TableIdentifier(r.name) + case r: UnresolvedCatalogRelation => r.tableMeta.identifier + case _ => return None + } + // Lookup the relation from the catalog by name. This either succeeds or returns some "not + // found" error. In the latter cases, return out of this rule without changing anything and let + // the analyzer return a proper error message elsewhere. + val lookup: LogicalPlan = try { + catalog.lookupRelation(tableName) } catch { - case _: NoSuchTableException => None + case _: AnalysisException => return None + } + val schema: StructType = lookup match { + case SubqueryAlias(_, r: UnresolvedCatalogRelation) => + StructType(r.tableMeta.schema.fields.dropRight( + enclosingInsert.get.partitionSpec.size)) + case _ => return None + } + // Rearrange the columns in the result schema to match the order of the explicit column list, + // if any. + val userSpecifiedCols: Seq[String] = enclosingInsert.get.userSpecifiedCols + if (userSpecifiedCols.isEmpty) { + return Some(schema) } + def normalize(str: String) = { + if (SQLConf.get.caseSensitiveAnalysis) str else str.toLowerCase() + } + val colNamesToFields: Map[String, StructField] = + schema.fields.map { + field: StructField => normalize(field.name) -> field + }.toMap + val userSpecifiedFields: Seq[StructField] = + userSpecifiedCols.map { + name: String => colNamesToFields.getOrElse(normalize(name), return None) + } + val userSpecifiedColNames: Set[String] = userSpecifiedCols.toSet + val nonUserSpecifiedFields: Seq[StructField] = + schema.fields.filter { + field => !userSpecifiedColNames.contains(field.name) + } + Some(StructType(userSpecifiedFields ++ + getStructFieldsForDefaultExpressions(nonUserSpecifiedFields))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 2483055880e9..222e195719d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -865,26 +865,26 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { withTable("t") { sql("create table t(i boolean, s bigint) using parquet") sql("insert into t values(true)") - checkAnswer(sql("select s from t where i = true"), Seq(Row(null))) + checkAnswer(spark.table("t"), Row(true, null)) } } // The default value for the DEFAULT keyword is the NULL literal. withTable("t") { sql("create table t(i boolean, s bigint) using parquet") sql("insert into t values(true, default)") - checkAnswer(sql("select s from t where i = true"), Seq(null).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(true, null)) } // There is a complex expression in the default value. 
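    // Here the default expression concat('abc', 'def') evaluates to the constant string 'abcdef'
    // when the row is inserted.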
withTable("t") { sql("create table t(i boolean, s string default concat('abc', 'def')) using parquet") sql("insert into t values(true, default)") - checkAnswer(sql("select s from t where i = true"), Seq("abcdef").map(i => Row(i))) + checkAnswer(spark.table("t"), Row(true, "abcdef")) } // The default value parses correctly and the provided value type is different but coercible. withTable("t") { sql("create table t(i boolean, s bigint default 42) using parquet") sql("insert into t values(false)") - checkAnswer(sql("select s from t where i = false"), Seq(42L).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(false, 42L)) } // There are two trailing default values referenced implicitly by the INSERT INTO statement. withTable("t") { @@ -894,74 +894,74 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } // The table has a partitioning column and a default value is injected. withTable("t") { - sql("create table t(i boolean, s bigint, q int default 42 ) using parquet partitioned by (i)") + sql("create table t(i boolean, s bigint, q int default 42) using parquet partitioned by (i)") sql("insert into t partition(i='true') values(5, default)") - checkAnswer(sql("select s from t where i = true"), Seq(5).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(5, 42, true)) } // The table has a partitioning column and a default value is added per an explicit reference. withTable("t") { sql("create table t(i boolean, s bigint default 42) using parquet partitioned by (i)") sql("insert into t partition(i='true') values(default)") - checkAnswer(sql("select s from t where i = true"), Seq(42L).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(42L, true)) } // The default value parses correctly as a constant but non-literal expression. withTable("t") { sql("create table t(i boolean, s bigint default 41 + 1) using parquet") sql("insert into t values(false, default)") - checkAnswer(sql("select s from t where i = false"), Seq(42L).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(false, 42L)) } // Explicit defaults may appear in different positions within the inline table provided as input // to the INSERT INTO statement. withTable("t") { sql("create table t(i boolean default false, s bigint default 42) using parquet") sql("insert into t values(false, default), (default, 42)") - checkAnswer(sql("select s from t where i = false"), Seq(42L, 42L).map(i => Row(i))) + checkAnswer(spark.table("t"), Seq(Row(false, 42L), Row(false, 42L))) } // There is an explicit default value provided in the INSERT INTO statement in the VALUES, // with an alias over the VALUES. withTable("t") { sql("create table t(i boolean, s bigint default 42) using parquet") sql("insert into t select * from values (false, default) as tab(col, other)") - checkAnswer(sql("select s from t where i = false"), Seq(42L).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(false, 42L)) } // The explicit default value arrives first before the other value. withTable("t") { sql("create table t(i boolean default false, s bigint) using parquet") sql("insert into t values (default, 43)") - checkAnswer(sql("select s from t where i = false"), Seq(43L).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(false, 43L)) } // The 'create table' statement provides the default parameter first. 
withTable("t") { sql("create table t(i boolean default false, s bigint) using parquet") sql("insert into t values (default, 43)") - checkAnswer(sql("select s from t where i = false"), Seq(43L).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(false, 43L)) } // The explicit default value is provided in the wrong order (first instead of second), but // this is OK because the provided default value evaluates to literal NULL. withTable("t") { sql("create table t(i boolean, s bigint default 42) using parquet") sql("insert into t values (default, 43)") - checkAnswer(sql("select s from t where i is null"), Seq(43L).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(null, 43L)) } // There is an explicit default value provided in the INSERT INTO statement as a SELECT. // This is supported. withTable("t") { sql("create table t(i boolean, s bigint default 42) using parquet") sql("insert into t select false, default") - checkAnswer(sql("select s from t where i = false"), Seq(42L).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(false, 42L)) } // There is a complex query plan in the SELECT query in the INSERT INTO statement. withTable("t") { sql("create table t(i boolean default false, s bigint default 42) using parquet") sql("insert into t select col, count(*) from values (default, default) " + "as tab(col, other) group by 1") - checkAnswer(sql("select s from t where i = false"), Seq(1).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(false, 1)) } // The explicit default reference resolves successfully with nested table subqueries. withTable("t") { sql("create table t(i boolean default false, s bigint) using parquet") sql("insert into t select * from (select * from values(default, 42))") - checkAnswer(sql("select s from t where i = false"), Seq(42L).map(i => Row(i))) + checkAnswer(spark.table("t"), Row(false, 42L)) } // There are three column types exercising various combinations of implicit and explicit // default column value references in the 'insert into' statements. Note these tests depend on @@ -1086,7 +1086,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } // The table has a partitioning column with a default value; this is not allowed. withTable("t") { - sql("create table t(i boolean default true, s bigint, q int default 42 ) " + + sql("create table t(i boolean default true, s bigint, q int default 42) " + "using parquet partitioned by (i)") assert(intercept[ParseException] { sql("insert into t partition(i=default) values(5, default)") @@ -1105,6 +1105,156 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } } + test("INSERT INTO with user specified columns and defaults: positive tests") { + Seq( + "insert into t (i, s) values (true, default)", + "insert into t (s, i) values (default, true)", + "insert into t (i) values (true)", + "insert into t (i) values (default)", + "insert into t (s) values (default)", + "insert into t (s) select default from (select 1)", + "insert into t (i) select true from (select 1)" + ).foreach { insert => + withTable("t") { + sql("create table t(i boolean default true, s bigint default 42) using parquet") + sql(insert) + checkAnswer(spark.table("t"), Row(true, 42L)) + } + } + // The table is partitioned and we insert default values with explicit column names. 
+ withTable("t") { + sql("create table t(i boolean, s bigint default 4, q int default 42) using parquet " + + "partitioned by (i)") + sql("insert into t partition(i='true') (s) values(5)") + sql("insert into t partition(i='false') (q) select 43") + sql("insert into t partition(i='false') (q) select default") + checkAnswer(spark.table("t"), + Seq(Row(5, 42, true), + Row(4, 43, false), + Row(4, 42, false))) + } + // When the CASE_SENSITIVE configuration is disabled, then using different cases for the + // required and provided column names is successful. + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTable("t") { + sql("create table t(i boolean, s bigint default 42, q int default 43) using parquet") + sql("insert into t (I, Q) select true from (select 1)") + checkAnswer(spark.table("t"), Row(true, 42L, 43)) + } + } + // When the USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES configuration is enabled, and no + // explicit DEFAULT value is available when the INSERT INTO statement provides fewer + // values than expected, NULL values are appended in their place. + withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { + withTable("t") { + sql("create table t(i boolean, s bigint) using parquet") + sql("insert into t (i) values (true)") + checkAnswer(spark.table("t"), Row(true, null)) + } + withTable("t") { + sql("create table t(i boolean default true, s bigint) using parquet") + sql("insert into t (i) values (default)") + checkAnswer(spark.table("t"), Row(true, null)) + } + withTable("t") { + sql("create table t(i boolean, s bigint default 42) using parquet") + sql("insert into t (s) values (default)") + checkAnswer(spark.table("t"), Row(null, 42L)) + } + withTable("t") { + sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") + sql("insert into t partition(i='true') (s) values(5)") + sql("insert into t partition(i='false') (q) select 43") + sql("insert into t partition(i='false') (q) select default") + checkAnswer(spark.table("t"), + Seq(Row(5, null, true), + Row(null, 43, false), + Row(null, null, false))) + } + } + } + + test("INSERT INTO with user specified columns and defaults: negative tests") { + val addOneColButExpectedTwo = "target table has 2 column(s) but the inserted data has 1 col" + val addTwoColButExpectedThree = "target table has 3 column(s) but the inserted data has 2 col" + // The missing columns in these INSERT INTO commands do not have explicit default values. 
+ withTable("t") { + sql("create table t(i boolean, s bigint) using parquet") + assert(intercept[AnalysisException] { + sql("insert into t (i) values (true)") + }.getMessage.contains(addOneColButExpectedTwo)) + } + withTable("t") { + sql("create table t(i boolean default true, s bigint) using parquet") + assert(intercept[AnalysisException] { + sql("insert into t (i) values (default)") + }.getMessage.contains(addOneColButExpectedTwo)) + } + withTable("t") { + sql("create table t(i boolean, s bigint default 42) using parquet") + assert(intercept[AnalysisException] { + sql("insert into t (s) values (default)") + }.getMessage.contains(addOneColButExpectedTwo)) + } + withTable("t") { + sql("create table t(i boolean, s bigint, q int default 43) using parquet") + assert(intercept[AnalysisException] { + sql("insert into t (i, q) select true from (select 1)") + }.getMessage.contains(addTwoColButExpectedThree)) + } + // When the USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES configuration is disabled, and no + // explicit DEFAULT value is available when the INSERT INTO statement provides fewer + // values than expected, the INSERT INTO command fails to execute. + withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") { + withTable("t") { + sql("create table t(i boolean, s bigint) using parquet") + assert(intercept[AnalysisException] { + sql("insert into t (i) values (true)") + }.getMessage.contains(addOneColButExpectedTwo)) + } + withTable("t") { + sql("create table t(i boolean default true, s bigint) using parquet") + assert(intercept[AnalysisException] { + sql("insert into t (i) values (default)") + }.getMessage.contains(addOneColButExpectedTwo)) + } + withTable("t") { + sql("create table t(i boolean, s bigint default 42) using parquet") + assert(intercept[AnalysisException] { + sql("insert into t (s) values (default)") + }.getMessage.contains(addOneColButExpectedTwo)) + } + withTable("t") { + sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") + assert(intercept[AnalysisException] { + sql("insert into t partition(i='true') (s) values(5)") + }.getMessage.contains(addTwoColButExpectedThree)) + } + withTable("t") { + sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") + assert(intercept[AnalysisException] { + sql("insert into t partition(i='false') (q) select 43") + }.getMessage.contains(addTwoColButExpectedThree)) + } + withTable("t") { + sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") + assert(intercept[AnalysisException] { + sql("insert into t partition(i='false') (q) select default") + }.getMessage.contains(addTwoColButExpectedThree)) + } + } + // When the CASE_SENSITIVE configuration is enabled, then using different cases for the required + // and provided column names results in an analysis error. + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTable("t") { + sql("create table t(i boolean default true, s bigint default 42) using parquet") + assert(intercept[AnalysisException] { + sql("insert into t (I) select true from (select 1)") + }.getMessage.contains("Cannot resolve column name I")) + } + } + } + test("Stop task set if FileAlreadyExistsException was thrown") { Seq(true, false).foreach { fastFail => withSQLConf("fs.file.impl" -> classOf[FileExistingTestFileSystem].getName,