[CARMEL-2028] Optimize bucket join in UPDATE (delta-io#6)
LantaoJin authored and GitHub Enterprise committed Jan 19, 2020
1 parent 74a5baa commit b84326c
Showing 11 changed files with 185 additions and 36 deletions.
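The change threads the table's BucketSpec from the catalog into the Delta Metadata, the read-side HadoopFsRelation, and the transactional writer, so the source-target join behind UPDATE can run as a bucket join instead of shuffling both sides. Below is a minimal, standalone sketch of the mechanism being relied on, written against plain Spark rather than this fork's Delta code path; the table names, bucket count, and local session are illustrative assumptions.

import org.apache.spark.sql.SparkSession

object BucketJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("bucket-join-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Two tables bucketed on the join key with the same bucket count.
    Seq((1, "a"), (2, "b")).toDF("id", "v")
      .write.bucketBy(8, "id").sortBy("id").mode("overwrite").saveAsTable("t_target")
    Seq((1, "x"), (3, "y")).toDF("id", "w")
      .write.bucketBy(8, "id").sortBy("id").mode("overwrite").saveAsTable("t_source")

    // With matching bucket specs on both sides, the physical plan should contain
    // no Exchange (shuffle) feeding the sort-merge join.
    spark.table("t_target").join(spark.table("t_source"), "id").explain()

    spark.stop()
  }
}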
4 changes: 4 additions & 0 deletions src/main/scala/io/delta/sql/analysis/DeltaSqlResolution.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentDate,
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.delta.DeltaErrors
import org.apache.spark.sql.execution.datasources._

class DeltaSqlResolution(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -33,6 +34,9 @@ class DeltaSqlResolution(spark: SparkSession) extends Rule[LogicalPlan] {
case u @ UpdateTableStatement(target, assignments, condition, source)
if !u.resolved && target.resolved &&
(if (source.isDefined) source.exists(_.resolved) else true) =>
target.collect {
case View(table, _, _) => throw DeltaErrors.cannotUpdateAViewException(table.identifier)
}
val resolvedAssignments = resolveAssignments(assignments, u)
val columns = resolvedAssignments.map(_.key.asInstanceOf[Attribute])
val values = resolvedAssignments.map(_.value)
4 changes: 4 additions & 0 deletions src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala
@@ -727,6 +727,10 @@ object DeltaErrors
new IllegalArgumentException(
s"Specified mode '$modeName' is not supported. Supported modes are: $supportedModes")
}

def cannotUpdateAViewException(tableIdentifier: TableIdentifier): Throwable = {
new AnalysisException(s"Can not update a View $tableIdentifier.")
}
}

/** The basic class for all Tahoe commit conflict exceptions. */
29 changes: 22 additions & 7 deletions src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala
@@ -35,12 +35,13 @@ import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.storage.LogStoreProvider
import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification}
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Cast, Expression, In, InSet, Literal}
import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -615,7 +616,7 @@
fileIndex,
partitionSchema = snapshot.metadata.partitionSchema,
dataSchema = snapshot.metadata.schema,
bucketSpec = None,
bucketSpec = snapshot.metadata.bucketSpec,
snapshot.fileFormat,
snapshot.metadata.format.options)(spark)

@@ -630,7 +631,9 @@
*/
def createRelation(
partitionFilters: Seq[Expression] = Nil,
timeTravel: Option[DeltaTimeTravelSpec] = None): BaseRelation = {
timeTravel: Option[DeltaTimeTravelSpec] = None,
table: Option[CatalogTable] = None,
parameters: Map[String, String] = Map.empty): BaseRelation = {

val versionToUse = timeTravel.map { tt =>
val (version, accessType) = DeltaTableUtils.resolveTimeTravelVersion(
@@ -650,11 +653,23 @@

new HadoopFsRelation(
fileIndex,
partitionSchema = snapshotToUse.metadata.partitionSchema,
dataSchema = snapshotToUse.metadata.schema,
bucketSpec = None,
partitionSchema = table match {
case Some(t) => t.partitionSchema
case None => snapshotToUse.metadata.partitionSchema
},
dataSchema = table match {
case Some(t) => t.schema
case None => snapshotToUse.metadata.schema
},
bucketSpec = table match {
case Some(t) => t.bucketSpec
case None => None
},
snapshotToUse.fileFormat,
snapshotToUse.metadata.format.options)(spark) with InsertableRelation {
table match {
case Some(t) => t.properties ++ snapshotToUse.metadata.format.options
case None => snapshotToUse.metadata.format.options
})(spark) with InsertableRelation {
def insert(data: DataFrame, overwrite: Boolean): Unit = {
val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
WriteIntoDelta(
@@ -18,17 +18,16 @@ package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.sources.DeltaSourceUtils
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogUtils

/**
* An identifier for a Delta table containing one of the path or the table identifier.
*/
case class DeltaTableIdentifier(
path: Option[String] = None,
table: Option[TableIdentifier] = None) {
assert(path.isDefined ^ table.isDefined, "Please provide one of the path or the table identifier")

val identifier: String = path.getOrElse(table.get.identifier)

@@ -88,9 +87,22 @@
*/
def apply(spark: SparkSession, identifier: TableIdentifier): Option[DeltaTableIdentifier] = {
if (isDeltaPath(spark, identifier)) {
Some(DeltaTableIdentifier(path = Option(identifier.table)))
} else if (DeltaTableUtils.isDeltaTable(spark, identifier)) {
Some(DeltaTableIdentifier(table = Option(identifier)))
return Some(DeltaTableIdentifier(path = Option(identifier.table)))
}
val catalog = spark.sessionState.catalog
val tableIsNotTemporaryTable = !catalog.isTemporaryTable(identifier)
val tableExists =
(identifier.database.isEmpty || catalog.databaseExists(identifier.database.get)) &&
catalog.tableExists(identifier)
if (tableIsNotTemporaryTable && tableExists) {
val catalogTable = catalog.getTableMetadata(identifier)
if (DeltaTableUtils.isDeltaTable(catalogTable)) {
Some(DeltaTableIdentifier(
path = Option(CatalogUtils.URIToString(catalogTable.location)),
table = Option(identifier)))
} else {
None
}
} else {
None
}
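With this change, a catalog-backed Delta table now resolves to a DeltaTableIdentifier carrying both its storage path and its table identifier, while temporary views and unknown tables still resolve to None. A hedged usage sketch of the resolution path; the database/table names and the active session are assumptions.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.delta.DeltaTableIdentifier

// `mydb.events` is a made-up table name; `spark` must be an active SparkSession.
def describeDeltaTable(spark: SparkSession): Unit = {
  DeltaTableIdentifier(spark, TableIdentifier("events", Some("mydb"))) match {
    case Some(id) => println(s"Delta table: path=${id.path}, table=${id.table}")
    case None     => println("Not a Delta table (or a temporary view)")
  }
}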
2 changes: 1 addition & 1 deletion src/main/scala/org/apache/spark/sql/delta/Snapshot.scala
@@ -180,7 +180,7 @@ class Snapshot(
index,
index.partitionSchema,
logSchema,
None,
None, // todo (lajin) bucketSpec = metadata.bucketSpec?
index.format,
Map.empty[String, String])(spark)
LogicalRelation(fsRelation)
@@ -25,8 +25,8 @@ import com.fasterxml.jackson.core.JsonGenerator
import com.fasterxml.jackson.databind.{JsonSerializer, SerializerProvider}
import com.fasterxml.jackson.databind.annotation.{JsonDeserialize, JsonSerialize}
import org.codehaus.jackson.annotate.JsonRawValue

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.{DataType, StructType}

@@ -189,7 +189,8 @@ case class Metadata(
partitionColumns: Seq[String] = Nil,
configuration: Map[String, String] = Map.empty,
@JsonDeserialize(contentAs = classOf[java.lang.Long])
createdTime: Option[Long] = Some(System.currentTimeMillis())) extends Action {
createdTime: Option[Long] = Some(System.currentTimeMillis()),
bucketSpec: Option[BucketSpec] = None) extends Action {

// The `schema` and `partitionSchema` methods should be vals or lazy vals, NOT
// defs, because parsing StructTypes from JSON is extremely expensive and has
@@ -25,17 +25,16 @@ import scala.util.control.NonFatal
import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.actions.{AddFile, CommitInfo, Metadata, Protocol}
import org.apache.spark.sql.delta.schema.SchemaUtils
import org.apache.spark.sql.delta.sources.{DeltaSourceUtils, DeltaSQLConf}
import org.apache.spark.sql.delta.sources.{DeltaSQLConf, DeltaSourceUtils}
import org.apache.spark.sql.delta.util.{DateFormatter, DeltaFileOperations, PartitionUtils, TimestampFormatter}
import org.apache.spark.sql.delta.util.FileNames.deltaFile
import org.apache.spark.sql.delta.util.SerializableFileStatus
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}

import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetToSparkSchemaConverter}
@@ -104,11 +103,39 @@ abstract class ConvertToDeltaCommandBase(
protected def getConvertProperties(
spark: SparkSession,
tableIdentifier: TableIdentifier): ConvertProperties = {
ConvertProperties(
None,
tableIdentifier.database,
tableIdentifier.table,
Map.empty[String, String])
def convertForPath(tableIdentifier: TableIdentifier): ConvertProperties = {
// convert to delta format.`path`
ConvertProperties(
None,
tableIdentifier.database,
tableIdentifier.table,
Map.empty[String, String])
}

def convertForTable(tableIdentifier: TableIdentifier): ConvertProperties = {
val table = spark.sessionState.catalog.getTableMetadata(tableIdentifier)
ConvertProperties(
Some(table),
table.provider,
CatalogUtils.URIToString(table.location),
table.properties)
}

val identifier = if (tableIdentifier.database.isEmpty) {
tableIdentifier.copy(database = Some(spark.sessionState.catalog.getCurrentDatabase))
} else {
tableIdentifier.copy(
database = tableIdentifier.database.map(_.toLowerCase(Locale.ROOT)))
}

val formats = Seq("delta", "parquet", "orc", "json", "csv", "tsv", "hive", "jdbc")
if (formats.exists(identifier.database.contains(_))) {
// convert to delta format.`path`
convertForPath(identifier)
} else {
// convert to delta table
convertForTable(identifier)
}
}

protected def handleExistingTransactionLog(
@@ -167,7 +194,8 @@
}

val schema = constructTableSchema(spark, dataSchema, partitionFields)
val metadata = Metadata(schemaString = schema.json, partitionColumns = partitionColNames)
val metadata = Metadata(schemaString = schema.json, partitionColumns = partitionColNames,
bucketSpec = convertProperties.catalogTable.flatMap(_.bucketSpec))
txn.updateMetadata(metadata)

val statsBatchSize =
@@ -177,6 +205,13 @@
val adds = batch.map(createAddFile(_, txn.deltaLog.dataPath, fs, spark.sessionState.conf))
adds.toIterator
}

// change provider to delta
if (convertProperties.catalogTable.isDefined) {
val newTable = convertProperties.catalogTable.get.copy(provider = Some("delta"))
spark.sessionState.catalog.alterTable(newTable)
}

streamWrite(
spark,
txn,
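getConvertProperties now distinguishes the path form (format.`path`) from a catalog table; for a catalog table the conversion reuses its CatalogTable metadata, carries its bucketSpec into the Delta Metadata, and flips the provider to delta. A hedged usage sketch of the two shapes follows; the paths and table names are made up, and the session is assumed to have the Delta SQL extension configured.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("convert-to-delta-sketch")
  .master("local[*]")
  .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
  .getOrCreate()

// Path-based conversion: the "database" part is a format keyword, so convertForPath is used.
spark.sql("CONVERT TO DELTA parquet.`/tmp/events_parquet`")

// Catalog-table conversion: convertForTable reads the CatalogTable, so a bucketed parquet
// table keeps its bucketSpec after conversion and its provider becomes "delta".
spark.sql("CONVERT TO DELTA mydb.events")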
@@ -36,7 +36,7 @@ import org.apache.spark.sql.functions._
/**
* Performs an update join with a source query/table.
*
* Issues an error message when the ON search_condition of the MERGE statement can match
* Issues an error message when the WHERE search_condition of the UPDATE statement can match
* a single row from the target table with multiple rows of the source table-reference.
*
* Algorithm:
@@ -46,7 +46,7 @@ import org.apache.spark.sql.functions._
* This is implemented as an inner-join using the given condition. See [[findTouchedFiles]]
* for more details.
*
* Phase 2: Read the touched files again and write new files with updated and/or inserted rows.
* Phase 2: Read the touched files again and write new files with updated rows.
*
* Phase 3: Use the Delta protocol to atomically remove the touched files and add the new files.
*
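A simplified sketch of Phase 1 in plain DataFrame code: join the source to the target on the UPDATE condition and collect the distinct data files backing the matched target rows. The real command records file names through a SetAccumulator and a UDF; the helper name and column name below are illustrative only.

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.input_file_name

// Illustrative helper, not the command's actual API.
def findTouchedFilePaths(target: DataFrame, source: DataFrame, condition: Column): Seq[String] = {
  target
    .withColumn("_target_file", input_file_name()) // physical file behind each target row
    .join(source, condition, "inner")              // keep only rows the UPDATE condition matches
    .select("_target_file")
    .distinct()
    .collect()
    .map(_.getString(0))
    .toSeq
}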
@@ -93,10 +93,8 @@ case class UpdateWithJoinCommand(
}
deltaTxn.commit(
deltaActions,
DeltaOperations.Merge(
Option(condition.sql),
updateClause.condition.map(_.sql),
None, None))
DeltaOperations.Update(
Option(condition.sql)))

// Record metrics
val stats = UpdateStats(
@@ -140,7 +138,7 @@

// Accumulator to collect all the distinct touched files
val touchedFilesAccum = new SetAccumulator[String]()
spark.sparkContext.register(touchedFilesAccum, "MergeIntoDelta.touchedFiles")
spark.sparkContext.register(touchedFilesAccum, "UpdateWithJoin.touchedFiles")

// UDFs to records touched files names and add them to the accumulator
val recordTouchedFileName = udf { (fileName: String) => {
@@ -181,7 +179,7 @@
logTrace(s"findTouchedFiles: matched files:\n\t${touchedFileNames.mkString("\n\t")}")

val nameToAddFileMap = generateCandidateFileMap(targetDeltaLog.dataPath, dataSkippedFiles)
val touchedAddFiles = touchedFileNames.map(f =>
val touchedAddFiles = touchedFileNames.filter(_.nonEmpty).map(f =>
getTouchedFile(targetDeltaLog.dataPath, f, nameToAddFileMap))

metrics("numFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles
@@ -216,7 +214,7 @@
val incrUpdatedCountExpr = makeMetricUpdateUDF("numRowsUpdated")
val incrNoopCountExpr = makeMetricUpdateUDF("numRowsCopied")

// Apply full outer join to find both, matches and non-matches. We are adding two boolean fields
// Apply a left outer join to find matches. We are adding two boolean fields
// with value `true`, one to each side of the join. Whether the source-side field is null
// after the left outer join will allow us to identify whether the resultant joined row
// matched a source row or is an unmatched target row with null on the source side.
@@ -225,7 +223,7 @@
.withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr))
val targetDF = Dataset.ofRows(spark, newTarget)
.withColumn(TARGET_ROW_PRESENT_COL, lit(true))
sourceDF.join(targetDF, new Column(condition), "fullOuter")
targetDF.join(sourceDF, new Column(condition), "left")
}

val joinedPlan = joinedDF.queryExecution.analyzed
@@ -157,7 +157,7 @@ trait TransactionalWrite extends DeltaLogging { self: OptimisticTransactionImpl
outputSpec = outputSpec,
hadoopConf = spark.sessionState.newHadoopConfWithOptions(metadata.configuration),
partitionColumns = partitioningColumns,
bucketSpec = None,
bucketSpec = snapshot.metadata.bucketSpec,
statsTrackers = statsTrackers,
options = Map.empty)
}