diff --git a/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/deltaMerge.scala b/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/deltaMerge.scala
index b664cf92196..30856385d6b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/deltaMerge.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/deltaMerge.scala
@@ -602,15 +602,26 @@ object DeltaMergeInto {
     // clause, then merge this schema with the target to give the final schema.
     def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType =
       StructType(sourceSchema.flatMap { field =>
-        val fieldPath = basePath :+ field.name.toLowerCase(Locale.ROOT)
-        val childAssignedInMergeClause = assignments.exists(_.startsWith(fieldPath))
+        val fieldPath = basePath :+ field.name
+
+        // Helper method to check if a given field path is a prefix of another path. Delegates
+        // equality to conf.resolver to correctly handle case sensitivity.
+        def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
+          prefix.length <= path.length && prefix.zip(path).forall {
+            case (prefixNamePart, pathNamePart) => conf.resolver(prefixNamePart, pathNamePart)
+          }
+
+        // Helper method to check if a given field path is equal to another path.
+        def isEqual(path1: Seq[String], path2: Seq[String]): Boolean =
+          path1.length == path2.length && isPrefix(path1, path2)
+
         field.dataType match {
           // Specifically assigned to in one clause: always keep, including all nested attributes
-          case _ if assignments.contains(fieldPath) => Some(field)
+          case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field)
           // If this is a struct and one of the children is being assigned to in a merge clause,
           // keep it and continue filtering children.
-          case struct: StructType if childAssignedInMergeClause =>
+          case struct: StructType if assignments.exists(isPrefix(fieldPath, _)) =>
             Some(field.copy(dataType = filterSchema(struct, fieldPath)))
           // The field isn't assigned to directly or indirectly (i.e. its children) in any non-*
           // clause. Check if it should be kept with any * action.
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala b/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala
index 9e17bbe5819..389c89d0654 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala
@@ -410,7 +410,9 @@ case class PreprocessTableMerge(override val conf: SQLConf)
     if (implicitColumns.isEmpty) {
       return (allActions, Set[String]())
     }
-    assert(finalSchema.size == allActions.size)
+    assert(finalSchema.size == allActions.size,
+      "Invalid number of columns in INSERT clause with generated columns. Expected schema: " +
+      s"$finalSchema, INSERT actions: $allActions")

     val track = mutable.Set[String]()
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/GeneratedColumnSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/GeneratedColumnSuite.scala
index 5f2dd60fa98..8f8dc91fd9f 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/GeneratedColumnSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/GeneratedColumnSuite.scala
@@ -1724,6 +1724,42 @@ trait GeneratedColumnSuiteBase extends GeneratedColumnTest {
     }
   }

+  test("MERGE INSERT with schema evolution on different name case") {
+    withTableName("source") { src =>
+      withTableName("target") { tgt =>
+        createTable(
+          tableName = src,
+          path = None,
+          schemaString = "c1 INT, c2 INT",
+          generatedColumns = Map.empty,
+          partitionColumns = Seq.empty
+        )
+        sql(s"INSERT INTO ${src} values (2, 4);")
+        createTable(
+          tableName = tgt,
+          path = None,
+          schemaString = "c1 INT, c3 INT",
+          generatedColumns = Map("c3" -> "c1 + 1"),
+          partitionColumns = Seq.empty
+        )
+        sql(s"INSERT INTO ${tgt} values (1, 2);")
+
+        withSQLConf(("spark.databricks.delta.schema.autoMerge.enabled", "true")) {
+          sql(s"""
+               |MERGE INTO ${tgt}
+               |USING ${src}
+               |on ${tgt}.c1 = ${src}.c1
+               |WHEN NOT MATCHED THEN INSERT (c1, C2) VALUES (${src}.c1, ${src}.c2)
+               |""".stripMargin)
+        }
+        checkAnswer(
+          sql(s"SELECT * FROM ${tgt}"),
+          Seq(Row(1, 2, null), Row(2, 3, 4))
+        )
+      }
+    }
+  }
+
   test("generated columns with cdf") {
     val tableName1 = "gcEnabledCDCOn"
     val tableName2 = "gcEnabledCDCOff"
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSchemaEvolutionSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSchemaEvolutionSuite.scala
index c81c174f778..1882884185b 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSchemaEvolutionSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSchemaEvolutionSuite.scala
@@ -451,6 +451,24 @@ trait MergeIntoSchemaEvolutionBaseTests {
     expectedWithoutEvolution = ((0, 0) +: (2, 2) +: (3, 30) +: (1, 1) +: Nil).toDF("key", "value")
   )

+  testEvolution(s"case-insensitive insert")(
+    targetData = Seq((0, 0), (1, 10), (3, 30)).toDF("key", "value"),
+    sourceData = Seq((1, 1), (2, 2)).toDF("key", "VALUE"),
+    clauses = insert("(key, value, VALUE) VALUES (s.key, s.value, s.VALUE)") :: Nil,
+    expected = ((0, 0) +: (1, 10) +: (3, 30) +: (2, 2) +: Nil).toDF("key", "value"),
+    expectedWithoutEvolution = ((0, 0) +: (1, 10) +: (3, 30) +: (2, 2) +: Nil).toDF("key", "value"),
+    confs = Seq(SQLConf.CASE_SENSITIVE.key -> "false")
+  )
+
+  testEvolution(s"case-sensitive insert")(
+    targetData = Seq((0, 0), (1, 10), (3, 30)).toDF("key", "value"),
+    sourceData = Seq((1, 1), (2, 2)).toDF("key", "VALUE"),
+    clauses = insert("(key, value, VALUE) VALUES (s.key, s.value, s.VALUE)") :: Nil,
+    expectErrorContains = "Cannot resolve s.value in INSERT clause",
+    expectErrorWithoutEvolutionContains = "Cannot resolve s.value in INSERT clause",
+    confs = Seq(SQLConf.CASE_SENSITIVE.key -> "true")
+  )
+
   testEvolution("evolve partitioned table")(
     targetData = Seq((0, 0), (1, 10), (3, 30)).toDF("key", "value"),
     sourceData = Seq((1, 1, "extra1"), (2, 2, "extra2")).toDF("key", "value", "extra"),
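
Reviewer note: a minimal, self-contained sketch (not part of the patch) of the isPrefix semantics the first hunk introduces. The object name, main method, and the two resolver values are illustrative stand-ins for SQLConf.resolver, which Spark picks based on spark.sql.caseSensitive:

object ResolverSketch {
  // Stand-ins for SQLConf.resolver: Spark resolves names case-insensitively
  // when spark.sql.caseSensitive=false (the default), by exact equality otherwise.
  val caseInsensitiveResolver: (String, String) => Boolean = (a, b) => a.equalsIgnoreCase(b)
  val caseSensitiveResolver: (String, String) => Boolean = (a, b) => a == b

  // Same shape as the isPrefix helper added in deltaMerge.scala, with the
  // resolver passed explicitly instead of coming from conf.
  def isPrefix(resolver: (String, String) => Boolean)
              (prefix: Seq[String], path: Seq[String]): Boolean =
    prefix.length <= path.length && prefix.zip(path).forall {
      case (prefixNamePart, pathNamePart) => resolver(prefixNamePart, pathNamePart)
    }

  def main(args: Array[String]): Unit = {
    val assignment = Seq("a", "B", "c") // column path assigned in a merge clause
    val fieldPath = Seq("a", "b")       // source field path, lowercased by the old code
    // The old check compared name parts with ==, so "b" vs "B" failed and the
    // field was wrongly dropped from the filtered source schema.
    println(assignment.startsWith(fieldPath))                         // false
    println(isPrefix(caseInsensitiveResolver)(fieldPath, assignment)) // true
    println(isPrefix(caseSensitiveResolver)(fieldPath, assignment))   // false
  }
}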