CDF for Merge command
See the project plan at delta-io#1105.

This PR adds CDF to the `MERGE` command.

Merge is implemented in two ways:

- Insert-only merges. For these we don't need to do anything special, since we only write `AddFile`s with the new rows.
    - However, our current implementation of insert-only merges doesn't correctly update the metric `numTargetRowsInserted`, which is used to check for data changes in [CDCReader](https://github.com/delta-io/delta/blob/master/core/src/main/scala/org/apache/spark/sql/delta/commands/cdc/CDCReader.scala#L313). This PR fixes that.

- For all other merges, we generate CDF rows for inserts, updates, and deletions. We do this by generating expression sequences for the CDF outputs (i.e., preimage, postimage, insert) on a clause-by-clause basis, and apply them to the rows in our joinedDF in addition to our existing main data output sequences (see the sketch after this list).
    - Changes made to `JoinedRowProcessor` make column `ROW_DELETED_COL` unnecessary, so this PR removes it.
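
As an end-to-end illustration (a sketch, not code from this PR; the table name, column names, and version number are made up), a merge into a CDF-enabled table can be run and its changes read back like this:

```scala
import io.delta.tables.DeltaTable
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("merge-cdf-sketch").getOrCreate()
import spark.implicits._

// Hypothetical source data; `target` is assumed to be a Delta table created with
// TBLPROPERTIES (delta.enableChangeDataFeed = true).
val source = Seq((1, "a2"), (3, "c")).toDF("id", "value")

DeltaTable.forName(spark, "target").as("t")
  .merge(source.as("s"), "t.id = s.id")
  .whenMatched().updateAll()     // produces update_preimage / update_postimage CDF rows
  .whenNotMatched().insertAll()  // produces insert CDF rows
  .execute()

// Read the change feed for the merge commit (assumed here to be version 1).
spark.read.format("delta")
  .option("readChangeFeed", "true")
  .option("startingVersion", 1)
  .table("target")
  .show()
```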

Tests are added in `MergeCDCSuite`.

Closes delta-io#1155

GitOrigin-RevId: 0386c6ff811abe433644b5f5f46a3c7d51001740
allisonport-db authored and jbguerraz committed Jul 6, 2022
1 parent 8ec17ec commit 7fd7bd2
Showing 2 changed files with 377 additions and 29 deletions.
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.sql.delta._
-import org.apache.spark.sql.delta.actions.{AddFile, FileAction}
+import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction}
import org.apache.spark.sql.delta.files._
import org.apache.spark.sql.delta.schema.{ImplicitMetadataOperation, SchemaUtils}
import org.apache.spark.sql.delta.sources.DeltaSQLConf
@@ -44,7 +44,7 @@ import org.apache.spark.sql.execution.command.LeafRunnableCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataTypes, StructType}

case class MergeDataSizes(
@JsonDeserialize(contentAs = classOf[java.lang.Long])
@@ -216,6 +216,7 @@ case class MergeIntoCommand(
import MergeIntoCommand._

import SQLMetrics._
+import org.apache.spark.sql.delta.commands.cdc.CDCReader._

override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE)
override val canOverwriteSchema: Boolean = false
@@ -254,6 +255,10 @@ case class MergeIntoCommand(
"numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"),
"numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"),
"numTargetFilesAdded" -> createMetric(sc, "number of files added to target"),
"numTargetChangeFilesAdded" ->
createMetric(sc, "number of change data capture files generated"),
"numTargetChangeFileBytes" ->
createMetric(sc, "total size of change data capture files generated"),
"numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"),
"numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"),
"numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"),
@@ -456,7 +461,7 @@

val outputColNames = getTargetOutputCols(deltaTxn).map(_.name)
// we use head here since we know there is only a single notMatchedClause
-val outputExprs = notMatchedClauses.head.resolvedActions.map(_.expr) :+ incrInsertedCountExpr
+val outputExprs = notMatchedClauses.head.resolvedActions.map(_.expr)
val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) =>
new Column(Alias(expr, name)())
}
@@ -478,6 +483,7 @@

val insertDf = sourceDF.join(targetDF, new Column(condition), "leftanti")
.select(outputCols: _*)
+.filter(new Column(incrInsertedCountExpr))

val newFiles = deltaTxn
.writeFiles(repartitionIfNeeded(spark, insertDf, deltaTxn.metadata.partitionColumns))
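
The added `.filter(new Column(incrInsertedCountExpr))` is what makes `numTargetRowsInserted` correct for insert-only merges: the metric expression is an always-true predicate that increments a counter as a side effect, so filtering on it counts every row without dropping any. A standalone sketch of that pattern (the names are invented, and the accumulator stands in for Delta's SQLMetric):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark = SparkSession.builder().appName("count-via-filter").getOrCreate()
import spark.implicits._

val numInserted = spark.sparkContext.longAccumulator("numInserted")

// Always-true predicate that bumps the counter each time it is evaluated.
val incrInsertedCount = udf { () => numInserted.add(1); true }

val insertDf = Seq(1, 2, 3).toDF("id").filter(incrInsertedCount())
insertDf.write.format("noop").mode("overwrite").save() // force evaluation
assert(numInserted.value == 3)
```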
@@ -494,7 +500,7 @@
metrics("numTargetBytesRemoved") += 0
metrics("numTargetPartitionsRemovedFrom") += 0
val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles)
metrics("numTargetFilesAdded") += newFiles.size
metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile])
metrics("numTargetBytesAdded") += addedBytes
metrics("numTargetPartitionsAddedTo") += addedPartitions
newFiles
@@ -503,12 +509,18 @@
/**
* Write new files by reading the touched files and updating/inserting data using the source
* query/table. This is implemented using a full|right-outer-join using the merge condition.
+*
+* Note that unlike the insert-only code paths, which have just one control column
+* (INCR_ROW_COUNT_COL), this method uses two additional control columns: ROW_DROPPED_COL for
+* dropping deleted rows, and CDC_TYPE_COLUMN_NAME for handling CDC when it is enabled.
*/
private def writeAllChanges(
spark: SparkSession,
deltaTxn: OptimisticTransaction,
filesToRewrite: Seq[AddFile]
): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") {
+import org.apache.spark.sql.catalyst.expressions.Literal.{TrueLiteral, FalseLiteral}

val targetOutputCols = getTargetOutputCols(deltaTxn)

// Generate a new logical plan that has same output attributes exprIds as the target plan.
@@ -548,36 +560,108 @@
}

val joinedPlan = joinedDF.queryExecution.analyzed
+val cdcEnabled = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(deltaTxn.metadata)

def resolveOnJoinedPlan(exprs: Seq[Expression]): Seq[Expression] = {
tryResolveReferencesForExpressions(spark, exprs, joinedPlan)
}

-def matchedClauseOutput(clause: DeltaMergeIntoMatchedClause): Seq[Expression] = {
+// ==== Generate the expressions to process full-outer join output and generate target rows ====
+// If there are N columns in the target table, there will be N + 3 columns after processing
+// - N columns for the target table
+// - ROW_DROPPED_COL to define whether the generated row should be dropped or written
+// - INCR_ROW_COUNT_COL containing a UDF to update the output row counter
+// - CDC_TYPE_COLUMN_NAME containing the type of change being performed in a particular row
+
+// To generate these N + 3 columns, we will generate N + 3 expressions and apply them to the
+// rows in the joinedDF. The CDC column will be either used for CDC generation or dropped before
+// performing the final write, and the other two will always be dropped after executing the
+// metrics UDF and filtering on ROW_DROPPED_COL.
+
+// We produce rows for both the main table data (with CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC)
+// and rows for the CDC data, which will be written to CDCReader.CDC_LOCATION.
+// See [[CDCReader]] for general details on how partitioning on the CDC type column works.
+
+// In the following two functions `matchedClauseOutput` and `notMatchedClauseOutput`, we
+// produce a Seq[Expression] for each intended output row.
+// Depending on the clause and whether CDC is enabled, we output between 1 and 3 rows, as a
+// Seq[Seq[Expression]].
+
+def matchedClauseOutput(clause: DeltaMergeIntoMatchedClause): Seq[Seq[Expression]] = {
val exprs = clause match {
case u: DeltaMergeIntoUpdateClause =>
-// Generate update expressions and set ROW_DELETED_COL = false
-u.resolvedActions.map(_.expr) :+ Literal.FalseLiteral :+ incrUpdatedCountExpr
+// Generate update expressions and set ROW_DROPPED_COL = false and
+// CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC
+val mainDataOutput = u.resolvedActions.map(_.expr) :+ FalseLiteral :+
+incrUpdatedCountExpr :+ Literal(CDC_TYPE_NOT_CDC)
+if (cdcEnabled) {
+// For the update preimage, we have to do a no-op copy with ROW_DROPPED_COL = false,
+// CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_PREIMAGE, and INCR_ROW_COUNT_COL as a no-op
+// (because the metric will be incremented in `mainDataOutput`)
+val preImageOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+
+Literal(CDC_TYPE_UPDATE_PREIMAGE)
+// For the update postimage, we have the same expressions as for mainDataOutput but with
+// INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in
+// `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_POSTIMAGE
+val postImageOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+
+Literal(CDC_TYPE_UPDATE_POSTIMAGE)
+Seq(mainDataOutput, preImageOutput, postImageOutput)
+} else {
+Seq(mainDataOutput)
+}
case _: DeltaMergeIntoDeleteClause =>
-// Generate expressions to set the ROW_DELETED_COL = true
-targetOutputCols :+ Literal.TrueLiteral :+ incrDeletedCountExpr
+// Generate expressions to set ROW_DROPPED_COL = true and CDC_TYPE_COLUMN_NAME =
+// CDC_TYPE_NOT_CDC
+val mainDataOutput = targetOutputCols :+ TrueLiteral :+ incrDeletedCountExpr :+
+Literal(CDC_TYPE_NOT_CDC)
+if (cdcEnabled) {
+// For delete we do a no-op copy with ROW_DROPPED_COL = false, INCR_ROW_COUNT_COL as a
+// no-op (because the metric will be incremented in `mainDataOutput`), and
+// CDC_TYPE_COLUMN_NAME = CDC_TYPE_DELETE
+val deleteCdcOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+
+Literal(CDC_TYPE_DELETE)
+Seq(mainDataOutput, deleteCdcOutput)
+} else {
+Seq(mainDataOutput)
+}
}
-resolveOnJoinedPlan(exprs)
+exprs.map(resolveOnJoinedPlan)
}

-def notMatchedClauseOutput(clause: DeltaMergeIntoInsertClause): Seq[Expression] = {
-resolveOnJoinedPlan(
-clause.resolvedActions.map(_.expr) :+ Literal.FalseLiteral :+ incrInsertedCountExpr)
+def notMatchedClauseOutput(clause: DeltaMergeIntoInsertClause): Seq[Seq[Expression]] = {
+// Generate insert expressions and set ROW_DROPPED_COL = false and
+// CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC
+val mainDataOutput = resolveOnJoinedPlan(
+clause.resolvedActions.map(_.expr) :+ FalseLiteral :+ incrInsertedCountExpr :+
+Literal(CDC_TYPE_NOT_CDC))
+if (cdcEnabled) {
+// For insert we have the same expressions as for mainDataOutput, but with
+// INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in
+// `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_INSERT
+val insertCdcOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ Literal(CDC_TYPE_INSERT)
+Seq(mainDataOutput, insertCdcOutput)
+} else {
+Seq(mainDataOutput)
+}
}

def clauseCondition(clause: DeltaMergeIntoClause): Expression = {
// if condition is None, then expression always evaluates to true
-val condExpr = clause.condition.getOrElse(Literal.TrueLiteral)
+val condExpr = clause.condition.getOrElse(TrueLiteral)
resolveOnJoinedPlan(Seq(condExpr)).head
}

+val outputRowSchema = if (!cdcEnabled) {
+deltaTxn.metadata.schema
+} else {
+deltaTxn.metadata.schema
+.add(ROW_DROPPED_COL, DataTypes.BooleanType)
+.add(INCR_ROW_COUNT_COL, DataTypes.BooleanType)
+.add(CDC_TYPE_COLUMN_NAME, DataTypes.StringType)
+}

val joinedRowEncoder = RowEncoder(joinedPlan.schema)
-val outputRowEncoder = RowEncoder(deltaTxn.metadata.schema).resolveAndBind()
+val outputRowEncoder = RowEncoder(outputRowSchema).resolveAndBind()

val processor = new JoinedRowProcessor(
targetRowHasNoMatch = resolveOnJoinedPlan(Seq(col(SOURCE_ROW_PRESENT_COL).isNull.expr)).head,
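
To make the clause-output scheme above concrete, here is an illustrative worked example (not part of the patch): for a target table `(id, value)` and a `WHEN MATCHED THEN UPDATE SET *` clause with CDC enabled, `matchedClauseOutput` produces three expression sequences, each N + 3 wide:

```
                 id    value    ROW_DROPPED_COL  INCR_ROW_COUNT_COL    CDC_TYPE_COLUMN_NAME
mainDataOutput   t.id  s.value  false            incrUpdatedCountExpr  CDC_TYPE_NOT_CDC
preImageOutput   t.id  t.value  false            true (no-op)          update_preimage
postImageOutput  t.id  s.value  false            true (no-op)          update_postimage
```

The main-data row lands in the rewritten data files, while the preimage/postimage rows are partitioned out to the change-data location by the CDC type column.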
@@ -587,15 +671,19 @@ case class MergeIntoCommand(
notMatchedConditions = notMatchedClauses.map(clauseCondition),
notMatchedOutputs = notMatchedClauses.map(notMatchedClauseOutput),
noopCopyOutput =
-resolveOnJoinedPlan(targetOutputCols :+ Literal.FalseLiteral :+ incrNoopCountExpr),
+resolveOnJoinedPlan(targetOutputCols :+ FalseLiteral :+ incrNoopCountExpr :+
+Literal(CDC_TYPE_NOT_CDC)),
deleteRowOutput =
-resolveOnJoinedPlan(targetOutputCols :+ Literal.TrueLiteral :+ Literal.TrueLiteral),
+resolveOnJoinedPlan(targetOutputCols :+ TrueLiteral :+ TrueLiteral :+
+Literal(CDC_TYPE_NOT_CDC)),
joinedAttributes = joinedPlan.output,
joinedRowEncoder = joinedRowEncoder,
outputRowEncoder = outputRowEncoder)

val outputDF =
Dataset.ofRows(spark, joinedPlan).mapPartitions(processor.processPartition)(outputRowEncoder)
+.drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL)

logDebug("writeAllChanges: join output plan:\n" + outputDF.queryExecution)

// Write to Delta
@@ -604,7 +692,9 @@ case class MergeIntoCommand(

// Update metrics
val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles)
metrics("numTargetFilesAdded") += newFiles.size
metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile])
metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile])
metrics("numTargetChangeFileBytes") += newFiles.collect{ case f: AddCDCFile => f.size }.sum
metrics("numTargetBytesAdded") += addedBytes
metrics("numTargetPartitionsAddedTo") += addedPartitions

@@ -747,14 +837,36 @@ object MergeIntoCommand {
val FILE_NAME_COL = "_file_name_"
val SOURCE_ROW_PRESENT_COL = "_source_row_present_"
val TARGET_ROW_PRESENT_COL = "_target_row_present_"
+val ROW_DROPPED_COL = "_row_dropped_"
+val INCR_ROW_COUNT_COL = "_incr_row_count_"

+/**
+ * @param targetRowHasNoMatch whether a joined row is a target row with no match in the source
+ * table
+ * @param sourceRowHasNoMatch whether a joined row is a source row with no match in the target
+ * table
+ * @param matchedConditions condition for each match clause
+ * @param matchedOutputs corresponding output for each match clause. For each clause we
+ * have 1-3 output rows, each of which is a sequence of expressions
+ * to apply to the joined row.
+ * @param notMatchedConditions condition for each not-matched clause
+ * @param notMatchedOutputs corresponding output for each not-matched clause. For each clause
+ * we have 1-2 output rows, each of which is a sequence of
+ * expressions to apply to the joined row.
+ * @param noopCopyOutput no-op expression to copy a target row to the output
+ * @param deleteRowOutput expression to drop a row from the final output. This is used for
+ * source rows that don't match any not-matched clauses.
+ * @param joinedAttributes schema of our outer-joined dataframe
+ * @param joinedRowEncoder joinedDF row encoder
+ * @param outputRowEncoder final output row encoder
+ */
class JoinedRowProcessor(
targetRowHasNoMatch: Expression,
sourceRowHasNoMatch: Expression,
matchedConditions: Seq[Expression],
-matchedOutputs: Seq[Seq[Expression]],
+matchedOutputs: Seq[Seq[Seq[Expression]]],
notMatchedConditions: Seq[Expression],
-notMatchedOutputs: Seq[Seq[Expression]],
+notMatchedOutputs: Seq[Seq[Seq[Expression]]],
noopCopyOutput: Seq[Expression],
deleteRowOutput: Seq[Expression],
joinedAttributes: Seq[Attribute],
@@ -774,20 +886,26 @@ object MergeIntoCommand {
val targetRowHasNoMatchPred = generatePredicate(targetRowHasNoMatch)
val sourceRowHasNoMatchPred = generatePredicate(sourceRowHasNoMatch)
val matchedPreds = matchedConditions.map(generatePredicate)
-val matchedProjs = matchedOutputs.map(generateProjection)
+val matchedProjs = matchedOutputs.map(_.map(generateProjection))
val notMatchedPreds = notMatchedConditions.map(generatePredicate)
-val notMatchedProjs = notMatchedOutputs.map(generateProjection)
+val notMatchedProjs = notMatchedOutputs.map(_.map(generateProjection))
val noopCopyProj = generateProjection(noopCopyOutput)
val deleteRowProj = generateProjection(deleteRowOutput)
val outputProj = UnsafeProjection.create(outputRowEncoder.schema)

-def shouldDeleteRow(row: InternalRow): Boolean =
-row.getBoolean(outputRowEncoder.schema.fields.size)
+// This accesses ROW_DROPPED_COL. If ROW_DROPPED_COL is not in outputRowEncoder.schema,
+// then CDC must be disabled and it is the column right after our output columns.
+def shouldDeleteRow(row: InternalRow): Boolean = {
+row.getBoolean(
+outputRowEncoder.schema.getFieldIndex(ROW_DROPPED_COL)
+.getOrElse(outputRowEncoder.schema.fields.size)
+)
+}

-def processRow(inputRow: InternalRow): InternalRow = {
+def processRow(inputRow: InternalRow): Iterator[InternalRow] = {
if (targetRowHasNoMatchPred.eval(inputRow)) {
// Target row did not match any source row, so just copy it to the output
-noopCopyProj.apply(inputRow)
+Iterator(noopCopyProj.apply(inputRow))
} else {
// identify which set of clauses to execute: matched or not-matched ones
val (predicates, projections, noopAction) = if (sourceRowHasNoMatchPred.eval(inputRow)) {
@@ -804,8 +922,9 @@ object MergeIntoCommand {
}

pair match {
-case Some((_, projections)) => projections.apply(inputRow)
-case None => noopAction.apply(inputRow)
+case Some((_, projections)) =>
+projections.map(_.apply(inputRow)).iterator
+case None => Iterator(noopAction.apply(inputRow))
}
}
}
@@ -814,7 +933,7 @@ object MergeIntoCommand {
val fromRow = outputRowEncoder.createDeserializer()
rowIterator
.map(toRow)
-.map(processRow)
+.flatMap(processRow)
.filter(!shouldDeleteRow(_))
.map { notDeletedInternalRow =>
fromRow(outputProj(notDeletedInternalRow))
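
The switch from `map(processRow)` to `flatMap(processRow)` above is the mechanical heart of the change: one joined input row may now fan out into several output rows (main data plus CDC preimage/postimage, insert, or delete rows), which are then filtered on the dropped flag. A self-contained toy model of that iterator shape (invented types standing in for `InternalRow` and the projections):

```scala
// Toy model of JoinedRowProcessor.processPartition: fan out, then drop flagged rows.
final case class OutRow(values: Seq[Any], dropped: Boolean)

def processPartition(rows: Iterator[Int]): Iterator[Seq[Any]] =
  rows
    .flatMap { r =>
      // One input row becomes a main-data row plus a CDC row.
      Iterator(
        OutRow(Seq(r, "not_cdc"), dropped = false),
        OutRow(Seq(r, "insert"), dropped = false))
    }
    .filter(!_.dropped) // shouldDeleteRow analogue
    .map(_.values)      // final projection analogue

assert(processPartition(Iterator(1, 2)).size == 4)
```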