Skip to content

Commit

Permalink
[SPARK-48683][SQL] Fix schema evolution with df.mergeInto losing `when` clauses
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

This PR fixes an issue in the `DataFrame.mergeInto` API where previously defined `when` clauses are lost after calling the `withSchemaEvolution()` method.
The issue is caused by not copying over existing clauses to the new writer.

### Why are the changes needed?

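It fixes a bug: `when` clauses defined before calling `withSchemaEvolution()` were dropped from the resulting writer.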

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47055 from xupefei/mergeinto-bugfix.

Authored-by: Paddy Xu <xupaddy@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
xupefei authored and HyukjinKwon committed Jun 24, 2024
1 parent 8b16196 commit e459674
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 15 deletions.
33 changes: 18 additions & 15 deletions sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class MergeIntoWriter[T] private[sql] (
table: String,
ds: Dataset[T],
on: Column,
schemaEvolutionEnabled: Boolean = false) {
private[sql] val schemaEvolutionEnabled: Boolean = false) {

private val df: DataFrame = ds.toDF()

Expand Down Expand Up @@ -172,6 +172,9 @@ class MergeIntoWriter[T] private[sql] (
*/
def withSchemaEvolution(): MergeIntoWriter[T] = {
  // Start from a fresh writer that has schema evolution switched on ...
  val evolved =
    new MergeIntoWriter[T](this.table, this.ds, this.on, schemaEvolutionEnabled = true)
  // ... then re-attach every clause already registered on this writer, so that
  // `when` clauses defined before this call are not lost.
  evolved.withNewMatchedActions(this.matchedActions: _*)
  evolved.withNewNotMatchedActions(this.notMatchedActions: _*)
  evolved.withNewNotMatchedBySourceActions(this.notMatchedBySourceActions: _*)
}

/**
Expand All @@ -196,18 +199,18 @@ class MergeIntoWriter[T] private[sql] (
qe.assertCommandExecuted()
}

private[sql] def withNewMatchedAction(action: MergeAction): MergeIntoWriter[T] = {
this.matchedActions = this.matchedActions :+ action
private[sql] def withNewMatchedActions(actions: MergeAction*): MergeIntoWriter[T] = {
this.matchedActions ++= actions
this
}

private[sql] def withNewNotMatchedAction(action: MergeAction): MergeIntoWriter[T] = {
this.notMatchedActions = this.notMatchedActions :+ action
private[sql] def withNewNotMatchedActions(actions: MergeAction*): MergeIntoWriter[T] = {
this.notMatchedActions ++= actions
this
}

private[sql] def withNewNotMatchedBySourceAction(action: MergeAction): MergeIntoWriter[T] = {
this.notMatchedBySourceActions = this.notMatchedBySourceActions :+ action
private[sql] def withNewNotMatchedBySourceActions(actions: MergeAction*): MergeIntoWriter[T] = {
this.notMatchedBySourceActions ++= actions
this
}
}
Expand All @@ -234,7 +237,7 @@ case class WhenMatched[T] private[sql](
* @return The MergeIntoWriter instance with the update all action configured.
*/
def updateAll(): MergeIntoWriter[T] = {
mergeIntoWriter.withNewMatchedAction(UpdateStarAction(condition))
mergeIntoWriter.withNewMatchedActions(UpdateStarAction(condition))
}

/**
Expand All @@ -245,7 +248,7 @@ case class WhenMatched[T] private[sql](
* @return The MergeIntoWriter instance with the update action configured.
*/
def update(map: Map[String, Column]): MergeIntoWriter[T] = {
mergeIntoWriter.withNewMatchedAction(
mergeIntoWriter.withNewMatchedActions(
UpdateAction(condition, map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq))
}

Expand All @@ -255,7 +258,7 @@ case class WhenMatched[T] private[sql](
* @return The MergeIntoWriter instance with the delete action configured.
*/
def delete(): MergeIntoWriter[T] = {
mergeIntoWriter.withNewMatchedAction(DeleteAction(condition))
mergeIntoWriter.withNewMatchedActions(DeleteAction(condition))
}
}

Expand All @@ -281,7 +284,7 @@ case class WhenNotMatched[T] private[sql](
* @return The MergeIntoWriter instance with the insert all action configured.
*/
def insertAll(): MergeIntoWriter[T] = {
mergeIntoWriter.withNewNotMatchedAction(InsertStarAction(condition))
mergeIntoWriter.withNewNotMatchedActions(InsertStarAction(condition))
}

/**
Expand All @@ -292,7 +295,7 @@ case class WhenNotMatched[T] private[sql](
* @return The MergeIntoWriter instance with the insert action configured.
*/
def insert(map: Map[String, Column]): MergeIntoWriter[T] = {
mergeIntoWriter.withNewNotMatchedAction(
mergeIntoWriter.withNewNotMatchedActions(
InsertAction(condition, map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq))
}
}
Expand All @@ -317,7 +320,7 @@ case class WhenNotMatchedBySource[T] private[sql](
* @return The MergeIntoWriter instance with the update all action configured.
*/
def updateAll(): MergeIntoWriter[T] = {
mergeIntoWriter.withNewNotMatchedBySourceAction(UpdateStarAction(condition))
mergeIntoWriter.withNewNotMatchedBySourceActions(UpdateStarAction(condition))
}

/**
Expand All @@ -328,7 +331,7 @@ case class WhenNotMatchedBySource[T] private[sql](
* @return The MergeIntoWriter instance with the update action configured.
*/
def update(map: Map[String, Column]): MergeIntoWriter[T] = {
mergeIntoWriter.withNewNotMatchedBySourceAction(
mergeIntoWriter.withNewNotMatchedBySourceActions(
UpdateAction(condition, map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq))
}

Expand All @@ -339,6 +342,6 @@ case class WhenNotMatchedBySource[T] private[sql](
* @return The MergeIntoWriter instance with the delete action configured.
*/
def delete(): MergeIntoWriter[T] = {
mergeIntoWriter.withNewNotMatchedBySourceAction(DeleteAction(condition))
mergeIntoWriter.withNewNotMatchedBySourceActions(DeleteAction(condition))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -943,4 +943,32 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase {
Row(3, Row("y1 ", "y2"), "hr"))) // update (not matched by source)
}
}

test("withSchemaEvolution carries over existing when clauses") {
  withTempView("source") {
    Seq(1, 2, 4).toDF("pk").createOrReplaceTempView("source")

    // Build an arbitrary merge that registers clauses of every kind:
    // two matched, one not-matched and one not-matched-by-source.
    val original = spark.table("source")
      .mergeInto("dummy", $"col" === $"col")
      .whenMatched(col("col") === 1)
      .updateAll()
      .whenMatched()
      .delete()
      .whenNotMatched(col("col") === 1)
      .insertAll()
      .whenNotMatchedBySource(col("col") === 1)
      .delete()
    val evolved = original.withSchemaEvolution()

    // The original writer holds the clauses that were defined on it.
    assert(original.matchedActions.length === 2)
    assert(original.notMatchedActions.length === 1)
    assert(original.notMatchedBySourceActions.length === 1)

    // The writer returned by withSchemaEvolution() carries the same clauses
    // and has schema evolution enabled.
    assert(original.matchedActions === evolved.matchedActions)
    assert(original.notMatchedActions === evolved.notMatchedActions)
    assert(original.notMatchedBySourceActions === evolved.notMatchedBySourceActions)
    assert(evolved.schemaEvolutionEnabled)
  }
}
}

0 comments on commit e459674

Please sign in to comment.